Commit 17386d7d authored by yeh-sudo, committed by LeiWang1999

[Doc] Fix typo and heading level in GEMV tutorial (#337)

This pull request includes a change to the `gemv.md` file. The change adjusts the heading levels of the titles in the document so that the heading hierarchy is correct.
# General Matrix-Vector Multiplication (GEMV)
Example code can be found at `examples/gemv/example_gemv.py`.
General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`.
## Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`.
A simple Triton kernel for GEMV might look like this:
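Below is a minimal sketch of such a kernel, assuming a row-major `(N, K)` matrix and one program instance per output element; the argument list, block size, and body are illustrative rather than the exact code from `examples/gemv/example_gemv.py`.

```python
import triton
import triton.language as tl


@triton.jit
def _gemv_naive(
    A,                      # pointer to the (K,) input vector
    B,                      # pointer to the (N, K) matrix, row-major
    C,                      # pointer to the (N,) output vector
    K,                      # reduction length
    BLOCK_K: tl.constexpr,  # elements reduced per loop step
):
    # one program instance per output element C[n]
    n = tl.program_id(0)
    acc = 0.0
    for k0 in range(0, K, BLOCK_K):
        offs = k0 + tl.arange(0, BLOCK_K)
        mask = offs < K
        # accumulate a BLOCK_K-wide chunk of the dot product in fp32
        a = tl.load(A + offs, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + n * K + offs, mask=mask, other=0.0).to(tl.float32)
        acc += tl.sum(a * b, axis=0)
    tl.store(C + n, acc.to(C.dtype.element_ty))
```

It would be launched with a one-dimensional grid of `N` programs, e.g. `_gemv_naive[(N,)](A, B, C, K, BLOCK_K=256)`.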
`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control.
## Naive Implementation in TileLang
If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example:
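The following is a simplified, hedged reconstruction rather than the exact listing from `examples/gemv/example_gemv.py`: it tiles the `(1, k) * (k, n)` product like a block GEMM, and the buffer annotation (`T.Tensor`), block sizes, and thread count are illustrative assumptions.

```python
import tilelang
import tilelang.language as T


def naive_gemv(N, K, block_N=128, block_K=128, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((1, K), dtype),   # treat the input vector as a (1, K) matrix
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((1, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=256) as bn:
            A_shared = T.alloc_shared((1, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((1, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                # cooperative copies of one K tile into shared memory
                T.copy(A[0, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bn * block_N], B_shared)
                # (1, block_K) x (block_K, block_N) tile "GEMM"
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[0, bn * block_N])

    return main
```

The returned `T.prim_func` can then be compiled (for example with `tilelang.compile` in recent TileLang versions) and benchmarked against torch.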
In this design, the first 128 threads act as the data producer and the last 128 threads act as the data consumer.
At this level, we extract very little of the GPU's compute capability: the kernel takes around **~0.17 ms**, compared to torch/cuBLAS at **~0.008 ms**, which is roughly 20x slower.
## More Concurrency
To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA.
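A rough sketch of this split-K idea is shown below, assuming a `(BLOCK_N, BLOCK_K)` thread layout, TileLang's `T.get_thread_binding` helper, and its `T.atomic_add` intrinsic; the names and shapes are assumptions based on the example, not a verbatim copy of `naive_splitk_gemv`.

```python
import tilelang.language as T


def naive_splitk_gemv(N, K, BLOCK_N=32, BLOCK_K=32, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((K,), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((N,), dtype),
    ):
        # BLOCK_N threads along N and BLOCK_K threads along K in each block
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
            tn = T.get_thread_binding(0)  # thread index along N
            tk = T.get_thread_binding(1)  # thread index along K
            C_accum = T.alloc_local((1,), accum_dtype)

            T.clear(C_accum)
            # each (tn, tk) thread accumulates a strided slice of one dot product
            # (assumes K is a multiple of BLOCK_K for brevity)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                C_accum[0] += (
                    A[bk * BLOCK_K + tk].astype(accum_dtype)
                    * B[bn * BLOCK_N + tn, bk * BLOCK_K + tk].astype(accum_dtype)
                )
            # combine the BLOCK_K partial sums that target the same output element
            T.atomic_add(C[bn * BLOCK_N + tn], C_accum[0].astype(dtype))

    return main
```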
By introducing parallelism along the K dimension, our kernel now achieves **~0.024 ms**, an improvement, but still not on par with torch/cuBLAS.
### Customizing Parallelism in K Dimension
If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter. This way, each thread handles multiple elements per iteration:
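As before, this is a hedged sketch rather than the exact `splitk_gemv` listing: the thread layout becomes `(BLOCK_N, reduce_threads)`, and each thread walks a contiguous `BLOCK_K // reduce_threads` slice of every K tile.

```python
import tilelang.language as T


def splitk_gemv(N, K, BLOCK_N=32, BLOCK_K=256, reduce_threads=8,
                dtype="float16", accum_dtype="float"):
    TILE_K = BLOCK_K // reduce_threads  # elements each thread handles per K tile

    @T.prim_func
    def main(
        A: T.Tensor((K,), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            C_accum = T.alloc_local((1,), accum_dtype)

            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # each thread owns a contiguous TILE_K slice of this K tile
                for k in T.serial(TILE_K):
                    C_accum[0] += (
                        A[bk * BLOCK_K + tk * TILE_K + k].astype(accum_dtype)
                        * B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k].astype(accum_dtype)
                    )
            T.atomic_add(C[bn * BLOCK_N + tn], C_accum[0].astype(dtype))

    return main
```

Choosing `reduce_threads` trades off per-thread work against the number of partial sums that must be combined at the end.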
## Vectorized Reads
GEMV is less compute-intensive than GEMM, so its arithmetic intensity is low and memory throughput becomes the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`:
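The sketch below is again an approximation of `splitk_gemv_vectorized` from the example, with an assumed vector width of 8 half-precision elements (i.e., 128-bit loads); the surrounding structure follows the earlier split-K sketch.

```python
import tilelang.language as T


def splitk_gemv_vectorized(N, K, BLOCK_N=32, reduce_threads=32,
                           dtype="float16", accum_dtype="float"):
    vec_size = 8                        # 8 x fp16 = 128-bit vector load
    BLOCK_K = reduce_threads * vec_size

    @T.prim_func
    def main(
        A: T.Tensor((K,), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((vec_size,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)

            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # vectorized load of 8 contiguous fp16 values from A
                for v in T.vectorized(vec_size):
                    A_local[v] = A[bk * BLOCK_K + tk * vec_size + v]
                for v in T.serial(vec_size):
                    C_accum[0] += (
                        A_local[v].astype(accum_dtype)
                        * B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * vec_size + v].astype(accum_dtype)
                    )
            T.atomic_add(C[bn * BLOCK_N + tn], C_accum[0].astype(dtype))

    return main
```

The vectorized loop lets the compiler emit 128-bit loads for the eight contiguous `float16` elements, which is what drives the bandwidth improvement.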
With vectorized reads, the kernel now finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.
## `tvm_thread_allreduce` Instead of `atomicAdd`
[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) implements an optimized all-reduce across a group of threads, which should outperform our plain shared-memory + `atomicAdd` approach; see `splitk_gemv_vectorized_tvm` in the example file.
With this optimization, the kernel latency now reduces from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!
## Autotune
`BLOCK_N`, `BLOCK_K`, `reduce_threads` are hyperparameters in our kernel, which can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations:
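The sketch below only shows the construction of the search space; the candidate values are illustrative, and enumerating configurations with `itertools.product` is just one common pattern for feeding an autotuner.

```python
import itertools


def get_gemv_configs():
    # candidate values are illustrative; adjust the ranges for your GPU
    block_n = [2, 4, 8, 32, 64, 128]
    block_k = [32, 64, 128, 256]
    reduce_threads = [4, 8, 16, 32]
    return [
        {"BLOCK_N": n, "BLOCK_K": k, "reduce_threads": r}
        for n, k, r in itertools.product(block_n, block_k, reduce_threads)
    ]
```

Each configuration is compiled and profiled, and the fastest one is kept; see `examples/gemv/example_gemv.py` for the complete autotuning setup used in this tutorial.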
The generated CUDA kernel corresponds closely to our `TileLang` program, with the necessary synchronization and low-level optimizations inserted automatically.
## Conclusion
### Benchmark Table on Hopper GPU
| Kernel Name | Latency |
|------------|------------|
| ... | ... |