Commit d44e291c authored by Wenhao Xie, committed by GitHub

[Doc] Convert docs from rst format to Markdown format. (#82)

* [CI] Clean up target repository before publishing documentation.

* [Doc] Convert docs from rst format to Markdown format.
parent d416bc40
......@@ -31,15 +31,26 @@ extensions = [
"sphinx_reredirects",
"sphinx.ext.mathjax",
"sphinx.ext.autosummary",
"myst_parser",
]
source_suffix = {
'.rst': 'restructuredtext',
'.md': 'markdown',
}
myst_enable_extensions = [
"colon_fence",
"deflist",
]
redirects = {"get_started/try_out": "../index.html#getting-started"}
language = "en"
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
# General Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
    <em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction:
* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM.
* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc.
* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center

Figure 1: High-level overview of the TileLang compilation flow.
```
In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance.
## Why Another GPU DSL?
TileLang emerged from the need for a DSL that:
......@@ -44,18 +40,15 @@ TileLang emerged from the need for a DSL that:
While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints.
## Matrix Multiplication Example
In this section, we demonstrate how to write a 2D-tiled matrix multiplication kernel at Level 2 in TileLang.
```{figure} ../_static/img/MatmulExample.png
:alt: Matmul Example
:align: center
```
### Basic Structure
Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses:
......@@ -66,13 +59,12 @@ Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplicatio
* **`T.Parallel(...)`** to parallelize data copy loops.
* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs).
```python
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
```
......@@ -113,107 +105,112 @@ Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplicatio
```python
    return main


# 1. Create the TileLang function
func = matmul(1024, 1024, 1024, 128, 128, 32)

# 2. JIT-compile the kernel for NVIDIA GPU
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")

import torch

# 3. Prepare input tensors in PyTorch
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# 4. Invoke the JIT-compiled kernel
c = jit_kernel(a, b)
ref_c = a @ b

# 5. Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 6. Inspect generated CUDA code (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)

# 7. Profile performance
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```
### Key Concepts
1. **Kernel Context**:
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    ...
```

- This sets up the block grid dimensions as `T.ceildiv(N, block_N)` x `T.ceildiv(M, block_M)`; for this 1024 x 1024 x 1024 example with 128 x 128 tiles, that is an 8 x 8 grid of thread blocks.
- `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads.
```{figure} ../_static/img/Parallel.png
:alt: Parallel
:align: center
```

2. **Shared & Fragment Memory**:

```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```

- `T.alloc_shared` allocates shared memory across the entire thread block.
- `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compiler’s layout inference assigns slices of this space to each thread.

3. **Software Pipelining**:

```python
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
    ...
```

- `T.Pipelined` automatically arranges asynchronous copy and compute instructions to overlap memory operations with arithmetic.
- The argument `num_stages=3` indicates the pipeline depth.

```{figure} ../_static/img/software_pipeline_inference.png
:alt: Software Pipeline Inference
:align: center
```

4. **Parallel Copy**:

```python
for k, j in T.Parallel(block_K, block_N):
    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
```

- `T.Parallel` marks the loop for thread-level parallelization.
- The compiler will map these loops to the available threads in the block.

5. **Tile-Level GEMM**:

```python
T.gemm(A_shared, B_shared, C_local)
```

- A single call that performs a tile-level matrix multiplication using the specified buffers.
- Under the hood, for NVIDIA targets, it can use CUTLASS/CuTe or WMMA instructions. On AMD GPUs, TileLang uses a separate HIP or composable kernel approach.

6. **Copying Back Results**:

```python
T.copy(C_local, C[by * block_M, bx * block_N])
```

- After computation, data in the local register fragment is written back to global memory.
## Comparison with Other DSLs
TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization, while letting the compiler handle many low-level details. However, TileLang also:
......@@ -221,13 +218,13 @@ TileLang Level 2 is conceptually similar to Triton in that the user can control
- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined.
- Enables mixing different levels in a single program—for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline-assembly and keep the rest in Level 2.
## Performance on Different Platforms

```{figure} ../_static/img/op_benchmark_consistent_gemm_fp16.png
:alt: Performance on Different Platforms
:align: center
```
When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks of an FP16 matrix multiply on several GPUs (e.g., 4090, A100, H100, MI300X), TileLang has shown:
......@@ -239,9 +236,7 @@ When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves perfo
These measurements will vary based on tile sizes, pipeline stages, and the hardware’s capabilities.
## Conclusion
This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code:
......@@ -255,12 +250,10 @@ By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) wi
For more advanced usage—including partial lowering, explicitly controlling thread primitives, or using inline assembly—you can explore Level 3. Meanwhile, for purely functional expressions and high-level scheduling auto-tuning, consider Level 1.
## Further Resources
* [TileLang GitHub](https://github.com/tile-ai/tilelang)
* [BitBLAS](https://github.com/tile-ai/bitblas)
* [Triton](https://github.com/openai/triton)
* [Cutlass](https://github.com/NVIDIA/cutlass)
* [PyCUDA](https://documen.tician.de/pycuda/)
(install)=

# Installation Guide

## Installing with pip
**Prerequisites for installation via wheel or PyPI:**
- **Operating System**: Ubuntu 20.04 or later
- **Python Version**: >= 3.8
- **CUDA Version**: >= 11.0
The easiest way to install TileLang is directly from PyPI using pip. To install the latest version, run the following command in your terminal:
```bash
pip install tilelang
```
Alternatively, you may choose to install TileLang using prebuilt packages available on the Release Page:
```bash
pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
```
To install the latest version of TileLang from the GitHub repository, you can run the following command:
```bash
pip install git+https://github.com/tile-ai/tilelang.git
```
After installing TileLang, you can verify the installation by running:
```bash
python -c "import tilelang; print(tilelang.__version__)"
```
## Building from Source
**Prerequisites for building from source:**
- **Operating System**: Linux
- **Python Version**: >= 3.7
- **CUDA Version**: >= 10.0
We recommend using a Docker container with the necessary dependencies to build TileLang from source. You can use the following command to run a Docker container with the required dependencies:
```bash
docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3
```
To build and install TileLang directly from source, follow these steps. This process requires certain pre-requisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands:
```bash
sudo apt-get update
sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
```
After installing the prerequisites, you can clone the TileLang repository and install it using pip:
```bash
git clone --recursive https://github.com/tile-ai/tilelang.git
cd tilelang
pip install .  # Please be patient, this may take some time.
```
If you want to install TileLang in development mode, you can run the following command:
```bash
pip install -e .
```
We currently provide three methods to install **TileLang**:
1. [Install from Source (using your own TVM installation)](#install-method-1)
2. [Install from Source (using the bundled TVM submodule)](#install-method-2)
3. [Install Using the Provided Script](#install-method-3)
(install-method-1)=

### Method 1: Install from Source (Using Your Own TVM Installation)
If you already have a compatible TVM installation, follow these steps:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```

**Note**: Use the `--recursive` flag to include necessary submodules.
2. **Configure Build Options**:
Create a build directory and specify your existing TVM path:

```bash
mkdir build
cd build
cmake .. -DTVM_PREBUILD_PATH=/your/path/to/tvm/build  # e.g., /workspace/tvm/build
make -j 16
```
3. **Set Environment Variables**:
Update `PYTHONPATH` to include the `tile-lang` Python module:

```bash
export PYTHONPATH=/your/path/to/tilelang/:$PYTHONPATH
# TVM_IMPORT_PYTHON_PATH is used by 3rd-party frameworks to import TVM
export TVM_IMPORT_PYTHON_PATH=/your/path/to/tvm/python
```
(install-method-2)=

### Method 2: Install from Source (Using the Bundled TVM Submodule)
If you prefer to use the built-in TVM version, follow these instructions:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```

**Note**: Ensure the `--recursive` flag is included to fetch submodules.
2. **Configure Build Options**:
Copy the configuration file and enable the desired backends (e.g., LLVM and CUDA):

```bash
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "set(USE_LLVM ON)" >> config.cmake
echo "set(USE_CUDA ON)" >> config.cmake
# or echo "set(USE_ROCM ON)" >> config.cmake to enable the ROCm runtime
cmake ..
make -j 16
```

The build outputs (e.g., `libtilelang.so`, `libtvm.so`, `libtvm_runtime.so`) will be generated in the `build` directory.
3. **Set Environment Variables**:
Ensure the `tile-lang` Python package is in your `PYTHONPATH`:

```bash
export PYTHONPATH=/your/path/to/tilelang/:$PYTHONPATH
```
(install-method-3)=

### Method 3: Install Using the Provided Script
For a simplified installation, use the provided script:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
2. **Run the Installation Script**:
```bash
bash install_cuda.sh
# or: bash install_amd.sh to enable the ROCm runtime
```
# The Tile Language: A Brief Introduction

## Programming Interface
The figure below depicts how **TileLang** programs are progressively lowered from a high-level description to hardware-specific executables. We provide three different programming interfaces—targeted at **Beginner**, **Developer**, and **Expert** users—that each reside at different levels in this lowering pipeline. The **Tile Language** also allows mixing these interfaces within the same kernel, enabling users to work at whichever level of abstraction best suits their needs.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center

Figure 1: High-level overview of the TileLang compilation flow.
```
## Programming Interfaces
1. **Beginner Level (Hardware-Unaware)**
- Intended for users who need to write code that is independent of specific hardware details.
......@@ -35,8 +29,7 @@ Programming Interfaces
- Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels.
- This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures.
## Compilation Flow
1. **Tile Program**
A high-level specification of the computation. Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives.
......@@ -56,78 +49,43 @@ Compilation Flow
6. **Hardware-Specific Executable/Runtime**
Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures.
## Tile-based Programming Model
[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in `TileLang`,
illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining,
and operator calls to manage data movement and computation with fine-grained control.
In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling
leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization
and reduce latency.
Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of `TileLang`
allows developers to reason about performance-critical optimizations within a user-friendly programming model.
```{figure} ../_static/img/MatmulExample.png
:align: center
:width: 100%
:alt: GEMM with Multi-Level Tiling on GPUs
:name: fig-overview-gemm

Figure 2: Optimizing GEMM with Multi-Level Tiling on GPUs via `TileLang`.
```
### Tile declarations
At the heart of our approach is the notion of *tiles* as first-class objects in the programming model. A tile represents a shaped portion of data, which can be owned and manipulated by a warp, thread block, or equivalent parallel unit. In the `Matmul` example, the `A` and `B` buffers are read in tiled chunks (determined by `block_M`, `block_N`, `block_K`) inside the kernel loop. With `T.Kernel`, `TileLang` defines the execution context, which includes the thread block index (`bx` and `by`) and the number of threads. These contexts can help compute the index for each thread block and make it easier for `TileLang` to automatically infer and optimize memory access and computation. Additionally, these contexts allow users to manually control the behavior of each independent thread within a thread block.
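To make these execution contexts concrete, the following minimal sketch (not part of the original example; the function name, tile sizes, and dtype are illustrative assumptions) uses only constructs introduced in this documentation (`T.Kernel`, `T.ceildiv`, `T.alloc_shared`, `T.copy`) to show how the block indices `bx` and `by` anchor the tile owned by each thread block:

```python
import tilelang.language as T

# Hypothetical problem and tile sizes, chosen only for illustration.
M, N, block_M, block_N = 1024, 1024, 128, 128
dtype = "float16"


@T.prim_func
def copy_tiles(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
    # ceildiv(N, block_N) x ceildiv(M, block_M) thread blocks, 128 threads each;
    # (bx, by) identify which tile this block owns.
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_N), dtype)
        # Stage this block's tile through shared memory, then write it back out.
        T.copy(A[by * block_M, bx * block_N], A_shared)
        T.copy(A_shared, B[by * block_M, bx * block_N])
```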
### Explicit Hardware Memory Allocation
A hallmark of `TileLang` is the ability to explicitly place these tile buffers in the hardware memory hierarchy. Rather than leaving it to a compiler's opaque optimization passes, `TileLang` exposes user-facing intrinsics that map directly to physical memory spaces or accelerator-specific constructs. In particular:
- `T.alloc_shared`: Allocates memory in a fast, on-chip storage space, which corresponds to shared memory on NVIDIA GPUs. Shared memory is ideal for caching intermediate data during computations, as it is significantly faster than global memory and allows for efficient data sharing between threads in the same thread block. For example, in matrix multiplication, tiles of matrices can be loaded into shared memory to reduce global memory bandwidth demands and improve performance.
- `T.alloc_fragment`: Allocates accumulators in fragment memory, which corresponds to register files on NVIDIA GPUs. By keeping inputs and partial sums in registers or hardware-level caches, latency is further minimized. Note that in this tile program, each tile allocates the same local buffers as shared memory, which might seem counterintuitive, as shared memory is generally faster but more abundant, whereas register file space is limited. This is because the allocation here refers to the register files for an entire thread block. `TileLang` uses a Layout Inference Pass during compilation to derive a Layout object `T.Fragment`, which determines how to allocate the corresponding register files for each thread. This process will be discussed in detail in subsequent sections.
Data transfer between global memory and hardware-specific memory can be managed using `T.copy`. Furthermore, hardware-specific buffers can be initialized using `T.clear` or `T.fill`. For data assignments, operations can also be performed in parallel using `T.Parallel`, as demonstrated in Layout Inference Pass in the following sections.
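As a minimal sketch of these primitives (the kernel name, shapes, and dtype are illustrative assumptions; only constructs named in this document are used, i.e. `T.alloc_shared`, `T.alloc_fragment`, `T.copy`, `T.clear`, `T.fill`, and `T.Parallel`), consider an element-wise kernel that stages a tile through shared memory, accumulates in registers, and writes the result back:

```python
import tilelang.language as T

block_M, block_N = 128, 128
dtype = "float16"


@T.prim_func
def add_one(A: T.Buffer((block_M, block_N), dtype), B: T.Buffer((block_M, block_N), dtype)):
    with T.Kernel(1, 1, threads=128) as (bx, by):
        # Explicit placement in the memory hierarchy.
        A_shared = T.alloc_shared((block_M, block_N), dtype)   # on-chip shared memory
        acc = T.alloc_fragment((block_M, block_N), dtype)      # register files for the block

        T.copy(A, A_shared)   # global -> shared
        T.clear(acc)          # zero-initialize (T.fill(acc, 1) would fill with ones instead)

        # Element-wise assignment, distributed over the block's threads.
        for i, j in T.Parallel(block_M, block_N):
            acc[i, j] = A_shared[i, j] + 1

        T.copy(acc, B)        # registers -> global
```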
```{figure} ../_static/img/LayoutInference.png
:align: center
:width: 100%
:alt: Layout Inference
```
# 👋 Welcome to Tile Language
[GitHub](https://github.com/tile-ai/tilelang)
Tile Language (tile-lang) is a concise domain-specific language designed to streamline
the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention).
By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM,
tile-lang allows developers to focus on productivity without sacrificing the
low-level optimizations necessary for state-of-the-art performance.
:::{toctree}
:maxdepth: 2
:caption: GET STARTED
get_started/Installation
get_started/overview
:::
:::{toctree}
:maxdepth: 1
:caption: TUTORIALS
tutorials/writing_kernels_with_tilelibrary
tutorials/writing_kernels_with_thread_primitives
tutorials/annotate_memory_layout
tutorials/debug_tools_for_tilelang
tutorials/auto_tuning
tutorials/jit_compilation
tutorials/pipelining_computations_and_data_movements
:::
:::{toctree}
:maxdepth: 1
:caption: DEEP LEARNING OPERATORS
deeplearning_operators/elementwise
deeplearning_operators/gemv
deeplearning_operators/matmul
deeplearning_operators/matmul_dequant
deeplearning_operators/flash_attention
deeplearning_operators/flash_linear_attention
deeplearning_operators/convolution
deeplearning_operators/tmac_gpu
:::
:::{toctree}
:maxdepth: 2
:caption: LANGUAGE REFERENCE
language_ref/ast
language_ref/primitives
language_ref/tilelibrary
:::
:::{toctree}
:maxdepth: 1
:caption: Privacy
privacy
:::
# Privacy
All data stays on the user's device and is not collected by the app.
......@@ -8,3 +8,4 @@ sphinxcontrib-napoleon==0.7
sphinxcontrib_httpdomain==1.8.1
furo
uvicorn
myst-parser