"cmake/vscode:/vscode.git/clone" did not exist on "7e923180c17c33bbe51c22c883cc67a64934da5f"
Unverified Commit d59a4782 authored by jacky.cheng's avatar jacky.cheng Committed by GitHub
Browse files

[3rdparty, document] Updated Documentation that covers performance tuning...


[3rdparty, document] Updated Documentation that covers performance tuning techniques for AMD Instinct GPUs. (#1871)
Co-authored-by: default avatarroot <root@dell300x-pla-t10-23.pla.dcgpu>
parent 104bf260
...@@ -2,12 +2,100 @@ ...@@ -2,12 +2,100 @@
This AppNote describes SGLang performance tuning techniques, the code harness, and the running steps for systems with AMD Instinct GPUs.
Harness code, examples, and steps are provided in detail to make the tuning workflow easy to reproduce and apply to your own workloads.
Three primary runtime areas are covered:
- Triton Kernels
- Torch Tunable Ops
- Torch Compile

## 1. Triton Kernels
To maximize Triton kernel efficiency, several strategies can be employed:

### Key Tuning Parameters:
- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
```python
import triton

@triton.autotune(configs=[
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
@triton.jit
def _triton_kernel_function():
...
```
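Tile sizes and the MFMA instruction shape can be swept in the same way. The sketch below is illustrative only: the kernel name and `key` fields are hypothetical, and good values of `BLOCK_M`/`BLOCK_N`/`BLOCK_K` and `matrix_instr_nonkdim` depend on the specific kernel and GPU.
```python
import triton

# Hypothetical GEMM-style search space: the BLOCK_* tile sizes trade memory
# traffic against compute, and matrix_instr_nonkdim selects the MFMA shape.
@triton.autotune(configs=[
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64,
                   'waves_per_eu': 2, 'matrix_instr_nonkdim': 16},
                  num_warps=8, num_stages=2),
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64,
                   'waves_per_eu': 4, 'matrix_instr_nonkdim': 16},
                  num_warps=4, num_stages=2),
], key=['M', 'N', 'K'])
@triton.jit
def _gemm_kernel_sketch():
    ...
```
Note that `OPTIMIZE_EPILOGUE` is set in the environment rather than in a kernel config, e.g. `OPTIMIZE_EPILOGUE=1 your_script.sh`.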
## 2. Torch Tunable Operations
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
### Key Environment Variables:
1. **PYTORCH_TUNABLEOP_ENABLED**:
- Default: `0`
- Set to `1` to enable TunableOp.
2. **PYTORCH_TUNABLEOP_TUNING**:
- Default: `1`
   - Set to `0` to disable tuning. While tuning is enabled and `PYTORCH_TUNABLEOP_ENABLED=1`, any op without a tuned entry is tuned on first use and the result is recorded.
3. **PYTORCH_TUNABLEOP_VERBOSE**:
- Default: `0`
- Set to `1` to enable verbose output for TunableOp.
### Usage Example:
To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
```bash
# Tuning run: benchmark configurations and record tuned entries
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
# Inference using the previously tuned entries
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
# Same, with verbose logging
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
```
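Recent ROCm PyTorch builds also expose these switches programmatically through the `torch.cuda.tunable` module; a minimal sketch, assuming a build that ships this module:
```python
import torch
import torch.cuda.tunable as tunable

# Equivalent to PYTORCH_TUNABLEOP_ENABLED=1 and PYTORCH_TUNABLEOP_TUNING=1.
tunable.enable(True)
tunable.tuning_enable(True)

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = a @ b  # first GEMM of this shape triggers tuning and records an entry

tunable.write_file()  # persist tuned entries so later runs can skip tuning
```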
## 3. Torch Compilation
The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
### Key Configurations:
1. **Max Autotune**:
- Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
2. **Fine-Grained Control**:
- Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
   - Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune_pointwise = True`.
3. **Backend Selection**:
- Use `torch._inductor.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
4. **Freezing for Inference**:
- Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
5. **Debugging**:
- Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
### Example Code Block:
```bash
# GEMM tuning
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
# Restrict the GEMM autotuning backend to Triton
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
# Inference with freezing, often a large improvement on AMD GPUs
TORCHINDUCTOR_FREEZING=1 your_script.sh
```
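The same knobs can be set from Python before compiling; a minimal sketch, where the `Linear` model and shapes are placeholders:
```python
import torch
import torch._inductor.config as inductor_config

# Equivalent to TORCHINDUCTOR_MAX_AUTOTUNE=1 / TORCHINDUCTOR_FREEZING=1.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "TRITON"
inductor_config.freezing = True

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
compiled = torch.compile(model, mode="max-autotune")

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y = compiled(x)  # first call benchmarks candidate Triton configs per shape
```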
## Reference
For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)