fusions.md 19.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Fusion torch.compile passes

vLLM applies a set of kernel/operator fusions at compile time (via custom [`torch.compile`](torch_compile.md) Inductor passes)
to separate optimizations from model definitions and avoid breaking layer abstractions in model code.
These fusions are controlled by fields in [`PassConfig`][vllm.config.compilation.PassConfig] and are automatically enabled
at appropriate [optimization levels](optimization_levels.md).

## Quick Reference

The table below maps each fusion to its controlling flag/config knob, the
operations it fuses, what level enables it by default, and an indicative speedup.
The Fullgraph column indicates whether the fusion requires the entire model graph to be
visible (either via Inductor partition or `splitting_ops=[]`),
and the last column indicates whether the fusion activates for all `num_tokens`
or just on the low or high end.

!!! info
    Speedup depends heavily on the exact model, batch size, and hardware.
    If tuning performance by hand, always benchmark your exact use-case with and without the fusion to verify the impact.

| Fusion                                                                         | `PassConfig` flag            | Fused operations                               | Default at                     | E2E Speedup        | Fullgraph | `num_tokens` |
22
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms)                  | `fuse_allreduce_rms`         | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20%              | No        | Low          |
| [Attention + Quant](#attention--quantization-fuse_attn_quant)                  | `fuse_attn_quant`            | Attention output → FP8/NVFP4 quant             | Off by default                 | 3-7%               | Yes       | Always       |
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache)             | `fuse_rope_kvcache`          | Rotary embedding → KV cache write              | O1 (ROCm/AITER only)           | TBD                | No        | Low          |
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion)                    | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding                 | Off by default                 | 2-3%               | No        | Low          |
| [Sequence Parallelism](#sequence-parallelism-enable_sp)                        | `enable_sp`                  | AllReduce → ReduceScatter + AllGather          | Off by default                 | Prereq for AsyncTP | Yes       | High         |
| [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms`            | GEMM → reduce-scatter / all-gather → GEMM      | Off by default                 | 7-10%              | Yes       | High         |
| [RMSNorm + Quant](#rmsnorm--quantization-fuse_norm_quant)                      | `fuse_norm_quant`            | RMSNorm (+residual add) → FP8/FP4 quant        | O1 (conditional)               | 1-4%               | No        | Always       |
| [SiLU+Mul + Quant](#silumul--quantization-fuse_act_quant)                      | `fuse_act_quant`             | SiLU+Mul activation → FP8/FP4 quant            | O1 (conditional)               | 1-4%               | No        | Always       |
| [RMSNorm + Padding](#rmsnorm--padding-fuse_act_padding)                        | `fuse_act_padding`           | Residual add + RMSNorm → padding               | O1 (ROCm/AITER only)           | TBD                | No        | Always       |

## Support Matrix

The table below lists the quantization schemes supported by each fusion on each platform.
**—** means the fusion is not available on that platform. The latest and in-progress work is available in the tracking issue:
[#36066](https://github.com/vllm-project/vllm/issues/36066)

| Fusion                       | SM100 (Blackwell)                        | SM90 (Hopper)                            | SM89 (Ada)                               | SM80 (Ampere) | ROCm                                     |
40
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
| `fuse_allreduce_rms`         | FP16/BF16, FP8 static, NVFP4             | FP16/BF16, FP8 static                    | —                                        | —             | —                                        |
| `fuse_attn_quant`\*          | FP8 static\*, NVFP4\*                    | FP8 static\*                             | FP8 static\*                             | —             | FP8 static\*                             |
| `fuse_rope_kvcache`          | —                                        | —                                        | —                                        | —             | FP16/BF16                                |
| `enable_qk_norm_rope_fusion` | FP16/BF16                                | FP16/BF16                                | FP16/BF16†                               | FP16/BF16†    | —                                        |
| `enable_sp`                  | FP16/BF16, FP8 static†                   | FP16/BF16, FP8 static                    | FP16/BF16†                               | FP16/BF16†    | —                                        |
| `fuse_gemm_comms`            | FP16/BF16, FP8 static†                   | FP16/BF16, FP8 static                    | FP16/BF16†                               | FP16/BF16†    | —                                        |
| `fuse_norm_quant`            | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | —             | FP8 static, FP8 per-token, FP8 per-group |
| `fuse_act_quant`             | FP8 static, NVFP4                        | FP8 static                               | FP8 static                               | —             | FP8 per-group                            |
| `fuse_act_padding`           | —                                        | —                                        | —                                        | —             | FP16/BF16                                |

\* `fuse_attn_quant` support depends on the attention backend in use; not all backends support
fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
for per-backend details.

`enable_sp` and `fuse_gemm_comms` are only autoconfigured for SM90 today;
other architectures support requires setting `PassConfig.sp_min_token_num` explicitly.
SM100 support also requires setting `VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel`.

## Enabling / Disabling Fusions

Fusions are exposed through `PassConfig`, which is nested inside `CompilationConfig`:

```python
from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

llm = LLM(
    model="...",
    optimization_level=2, # Default optimization level
    compilation_config=CompilationConfig(
        pass_config=PassConfig(
            fuse_norm_quant=True,
            fuse_act_quant=True,
            fuse_allreduce_rms=False,  # disable a specific fusion
        )
    ),
)
```

Fusions can also be enabled using command-line flags with any `vllm ...` command:

```bash
# Enable O2 defaults, but turn off allreduce fusion
vllm serve meta-llama/Llama-3.1-8B-Instruct -O2 -cc.pass_config.fuse_allreduce_rms=False

# The above is equivalent to the more verbose:
vllm serve meta-llama/Llama-3.1-8B-Instruct -O2 --compilation-config '{"pass_config": {"fuse_allreduce_rms": false}}'

# Same syntax in other commands, e.g. vllm bench:
vllm bench latency --model=meta-llama/Llama-3.1-8B-Instruct -O2 -cc.pass_config.fuse_allreduce_rms=False
```

Fields set explicitly by the user always take precedence over optimization-level defaults.

## Fusion Details

### AllReduce + RMSNorm (`fuse_allreduce_rms`)

!!! warning
    TP+DP and TP+PP combinations are currently broken
    ([#34458](https://github.com/vllm-project/vllm/issues/34458) and
    [#35426](https://github.com/vllm-project/vllm/issues/35426)).
    Only supported on NVIDIA Hopper (SM90) and Blackwell (SM100) with FlashInfer installed.

**What it fuses.** Fuses the tensor-parallel all-reduce collective with the subsequent residual add,
RMSNorm, and optionally a quantization step into a single FlashInfer / TRT-LLM communication kernel.
This fusion is only profitable for small `num_tokens`,
so the fusion is only performed in the lower compiled range.

Patterns covered:

- `AllReduce → RMSNorm(+residual_add)`: CUDA sm90+ with FlashInfer
- `AllReduce → RMSNorm(+residual_add) → FP8 static quant`: CUDA sm90+ with FlashInfer
- `AllReduce → RMSNorm(+residual_add) → NVFP4 dynamic quant`: CUDA sm100+ with FlashInfer

The maximum tensor size below which the fused kernel is used is hardware-dependent (64 MB for TP=2
on SM90/SM100) and configurable via `PassConfig.fi_allreduce_fusion_max_size_mb`.

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/allreduce_rms_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/allreduce_rms_fusion.py)
- FlashInfer all-reduce: [`vllm/distributed/device_communicators/flashinfer_all_reduce.py`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/flashinfer_all_reduce.py)
- Benchmark: [`benchmarks/kernels/benchmark_fused_collective.py`](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_fused_collective.py)

### Attention + Quantization (`fuse_attn_quant`)

!!! info
    `fuse_attn_quant` is currently not enabled at any optimization level by default and must be set
    explicitly. It requires the full model graph to be visible (Inductor partition or `splitting_ops=[]`).

**What it fuses.** Fuses the attention output quantization directly after the attention computation,
eliminating a full-precision memory round-trip of the attention output. Patterns covered:

`Attention → FP8 static quant`:

- `TRITON_ATTN`: CUDA, ROCm
- `FLASHINFER`: CUDA sm100+ with FlashInfer installed
- `ROCM_ATTN`: ROCm
- `ROCM_AITER_UNIFIED_ATTN`: ROCm with AITER

`Attention → NVFP4 dynamic quant`:

- `FLASHINFER`: CUDA sm100+ with FlashInfer installed

Other attention backends do not support fused output quantization yet.

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
- Attention backends: [`vllm/v1/attention/backends/`](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/)

### RoPE + KV-Cache Update (`fuse_rope_kvcache`)

!!! info
    ROCm/AITER-only. Not available on NVIDIA CUDA or CPU. The fusion is only enabled for
    `num_tokens ≤ 256` by default due to AITER fused kernel performance issues.
    This threshold is configurable via `PassConfig.rope_kvcache_fusion_max_token_num`.

**What it fuses.** Fuses the rotary positional embedding kernel with the KV-cache scatter/write into
a single kernel, avoiding separate reads and writes of the key and value tensors.

Requires: AMD ROCm with AITER enabled, the `rotary_embedding` custom op active (automatic),
and the `kv_cache` update op visible in the graph: either by using Inductor graph partition
or removed from `splitting_ops`.
If these conditions are set, the fusion is enabled automatically for optimization level O1 and above.

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/rope_kvcache_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rope_kvcache_fusion.py)

### Sequence Parallelism (`enable_sp`)

**What it fuses.** Replaces all-reduce collectives with reduce-scatter + local RMSNorm + all-gather,
splitting the sequence dimension across TP ranks. This restructures the graph so the subsequent AsyncTP
pass can fuse the reduce-scatter / all-gather with the surrounding GEMMs.

Sequence Parallelism itself does not directly improve performance; it is a prerequisite for the
AsyncTP pass (`fuse_gemm_comms`). SP is only applied above a minimum token threshold that is
autoconfigured based on device capability and model `hidden_size`. Currently only active on
H100/SM90 for models with `hidden_size >= 8192`. The threshold is configurable via
`PassConfig.sp_min_token_num`.

The general transformation:

```text
Input → AllReduce → RMSNorm → Output
becomes:
Input → ReduceScatter → local RMSNorm → AllGather → Output
```

Patterns covered:

- First block: `AllReduce → RMSNorm``ReduceScatter → RMSNorm → AllGather`
- Middle blocks: `AllReduce → fused_add_RMSNorm``ReduceScatter → fused_add_RMSNorm → AllGather`
- Both with optional `→ FP8 static quant` suffix

Requires: `use_inductor_graph_partition=True` **or** piecewise compilation with static sizes
divisible by `tensor_parallel_size`.

Supported hardware: Only tested on NVIDIA CUDA, possibly works on ROCm. FP8 all-gather requires sm90+.

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/sequence_parallelism.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/sequence_parallelism.py)

### AsyncTP GEMM + Collective Overlap (`fuse_gemm_comms`)

!!! info
    Requires `enable_sp=True` (enabled automatically). This pass is a no-op if Sequence Parallelism has not been applied.

**What it fuses.** After Sequence Parallelism transforms the graph, fuses GEMM kernels with the
surrounding reduce-scatter (output projection) and all-gather (input projection) using
`torch.ops.symm_mem` symmetric-memory primitives, overlapping communication and computation.
This overlap is only profitable for large `num_tokens`, so the fusion (and preceding SP)
is only performed in the higher compiled range above `PassConfig.sp_min_token_num`.

Patterns covered:

- `GEMM → reduce-scatter``fused_matmul_reduce_scatter`
- `all-gather → GEMM``all_gather_matmul`
- FP8 scaled variants of both patterns

Supported hardware: NVIDIA CUDA with symmetric-memory (`torch.distributed._symmetric_memory`) support.

On B200, pattern-matching fp8 FlashInfer scaled MM is not supported, so it must be disabled
([#27893](https://github.com/vllm-project/vllm/issues/27893))

```shell
VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel ...
```

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/collective_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/collective_fusion.py)
- Sequence parallelism pass: [`vllm/compilation/passes/fusion/sequence_parallelism.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/sequence_parallelism.py)

### QK Norm + RoPE (`enable_qk_norm_rope_fusion`)

!!! info
    Only applicable to models that apply per-head RMSNorm to Q and K before rotary positional
    embedding (e.g. Qwen). Not enabled by default at any optimization level due to perf issues on H100:
    [#34391](https://github.com/vllm-project/vllm/issues/34391)

**What it fuses.** Fuses the sequence: split QKV → reshape → Q/K RMSNorm → reshape → rotary
embedding into a single `fused_qk_norm_rope` CUDA kernel.

```text
# Unfused:
q, k, v = split(qkv)
q_norm = rms_norm(q.view(heads))
k_norm = rms_norm(k.view(kv_heads))
q_rope, k_rope = rotary_embedding(q_norm, k_norm, ...)

# Fused:
fused_qk_norm_rope(qkv, ...)
```

Supported hardware: CUDA (sm80+) only, tested only on sm90 and sm100.

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/qk_norm_rope_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py)
- CUDA kernel: [`csrc/ops.h`](https://github.com/vllm-project/vllm/blob/main/csrc/ops.h) (`fused_qk_norm_rope`)

### RMSNorm + Quantization (`fuse_norm_quant`)

!!! warning
    On NVIDIA, Inductor actually generates a faster fused kernel than our custom CUDA kernel.
    Hence, this fusion is only enabled when either `rms_norm` or `quant_fp8` is using a custom kernel.

**What it fuses.** Combines the custom `rms_norm` / `fused_add_rms_norm`
operations with subsequent quantization into a single fused kernel,
eliminating an intermediate read/write of the full-precision activation tensor.
Two variants are fused:

- *Plain RMSNorm + quant*: `rms_norm(x) → quant_fp8(y)`
- *Fused-add RMSNorm + quant*: `fused_add_rms_norm(x, residual) → quant_fp8(y)` — also updates the residual in-place.

Note that AITER fusions are currently in a separate pass in `vllm.compilation.passes.fusion.rocm_aiter_fusion`.

Supported quantization scheme/hardware combinations:

- FP8 static per-tensor: CUDA & HIP kernel
- FP8 dynamic per-token: CUDA & HIP kernel, AITER
- FP8 dynamic per-token-group (128/64): CUDA & HIP kernel, AITER

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/rms_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rms_quant_fusion.py)
- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py)
- CUDA/HIP kernels: [`csrc/layernorm_quant_kernels.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_quant_kernels.cu)

### SiLU+Mul + Quantization (`fuse_act_quant`)

!!! warning
    Same as `fuse_norm_quant`: on NVIDIA, Inductor generates a faster fused kernel than our custom ops.
    This fusion is only enabled when either `silu_and_mul` or `quant_fp8` are using a custom kernel,
    or for NVFP4-quantized models (where FP4 quant is always a custom op).

**What it fuses.** Fuses the `silu_and_mul` gate-up projection activation with subsequent quantization into a single kernel,
avoiding materialization of the full-precision post-activation tensor.

Note that AITER fusions are in a separate pass in `vllm.compilation.passes.fusion.rocm_aiter_fusion`.

Supported quantization scheme/hardware combinations:

- FP8 static per-tensor: CUDA & HIP kernel
- NVFP4 dynamic: CUDA sm100+ only with FlashInfer
- FP8 per-token-group (128): ROCm AITER only

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/act_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/act_quant_fusion.py)
- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py)
- CUDA/HIP kernels: [`csrc/quantization/`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/)

### RMSNorm + Padding (`fuse_act_padding`)

!!! info
    ROCm/AITER-only. Targeted at GPT-OSS models.

**What it fuses.** Fuses a residual add + RMSNorm with a subsequent padding operation that pads
the hidden dimension to a multiple required by downstream AITER Triton GEMM kernels.

Requires: AMD ROCm with AITER RMSNorm enabled. Enabled by default in optimization level O1 and above
when the hidden size is 2880 and AITER Triton GEMMs *not* enabled.

**Code locations.**

- Pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) (`RocmAiterTritonAddRMSNormPadFusionPass`)

## See Also

- [Optimization Levels](optimization_levels.md) — high-level presets that set
  fusion defaults.
- [torch.compile in vLLM](torch_compile.md) — how the Inductor pass pipeline
  works.
- [Attention Backends](attention_backends.md) — attention-specific kernel
  selection.