attn_quant_fusion.py 14 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
from collections.abc import Callable
6

7
8
9
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

10
from vllm.config import VllmConfig, get_layers_from_vllm_config
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.attention import Attention
13
from vllm.model_executor.layers.quantization.utils.quant_utils import (
14
    QuantKey,
15
    kNvfp4Dynamic,
16
17
    kStaticTensorScale,
)
18
from vllm.platforms import current_platform
19
from vllm.utils.math_utils import round_up
20
from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name
21

22
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
23
from .matcher_utils import MatcherQuantFP8
24
from .rms_quant_fusion import QUANT_OPS
25
26

logger = init_logger(__name__)

# Platform-specific FP8 dtype, as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 (NVFP4) outputs are held in uint8 tensors.
FP4_DTYPE = torch.uint8

# Ops as they appear in the traced graph that the fusion patterns match.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default

# Static, symmetric, per-tensor-scale FP8 quantization: the FP8 scheme the
# attention + quant fusion below handles.
_FP8_QUANT_KEY = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=True)
36

37
38

class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.

    When `_USE_LAYERNAME` is enabled (torch >= 2.11), `layer_name` is an
    explicit pattern input, which the pattern matcher treats as a wildcard
    matching hoisted LayerName placeholders. A single instance then matches
    any attention layer with this instance's `num_heads`/`head_size`.
    Otherwise the encoded layer name is baked into the pattern as a
    constant, and the instance only matches its own layer.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype):
        # Only the layer's name and head geometry are captured; the graph
        # patterns below are built from these.
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._head_size = layer.head_size
        # dtype of the example activations built in get_inputs().
        self._dtype = dtype
        self._quant_matcher = MatcherQuantFP8(_FP8_QUANT_KEY)

    def _traced_pattern(
        self, q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
    ):
        """Unfused graph: attention without output quant, reshape, FP8 quant.

        Shared by both `pattern` variants so the two cannot drift apart.
        """
        at1 = auto_functionalized(
            ATTN_OP,
            query=q,
            key=k,
            value=v,
            output=output_attn,
            layer_name=layer_name,
            output_scale=None,
            output_block_scale=None,
            kv_cache_dummy_dep=kv_cache_dummy_dep,
        )
        # at1[1] is the `output` buffer written by attention; flatten the
        # head dims to [q.shape[0], num_heads * head_size] before quant.
        attn_out_view = RESHAPE_OP(
            at1[1], [q.shape[0], self._num_heads * self._head_size]
        )
        return self._quant_matcher(attn_out_view, scale)[0]

    def _traced_replacement(self, q, k, v, scale, kv_cache_dummy_dep, layer_name):
        """Fused graph: attention writes FP8 output directly using `scale`.

        Shared by both `replacement` variants so the two cannot drift apart.
        """
        # A fresh FP8 output buffer replaces the pattern's `output_attn`.
        output_attn = torch.empty(
            [q.shape[0], self._num_heads, self._head_size],
            dtype=FP8_DTYPE,
            device=q.device,
        )
        at1 = auto_functionalized(
            ATTN_OP,
            query=q,
            key=k,
            value=v,
            output=output_attn,
            layer_name=layer_name,
            output_scale=scale,
            output_block_scale=None,
            kv_cache_dummy_dep=kv_cache_dummy_dep,
        )
        return RESHAPE_OP(at1[1], [-1, self._num_heads * self._head_size])

    @property
    def pattern(self) -> Callable[..., torch.Tensor]:
        if _USE_LAYERNAME:
            # layer_name is a pattern input -> wildcard over layer names.

            def _pattern_with_ln(  # type: ignore[misc]
                q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
            ):
                return self._traced_pattern(
                    q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
                )

            return _pattern_with_ln

        # Original behavior: the encoded layer name is a closure constant.
        ln = _encode_layer_name(self._layer_name)

        def _pattern(q, k, v, output_attn, scale, kv_cache_dummy_dep):
            return self._traced_pattern(
                q, k, v, output_attn, scale, kv_cache_dummy_dep, ln
            )

        return _pattern

    @property
    def replacement(self) -> Callable[..., torch.Tensor]:
        # The replacement keeps the pattern's signature; `output_attn` is
        # accepted but unused because a new FP8 buffer is allocated.
        if _USE_LAYERNAME:

            def _replacement_with_ln(  # type: ignore[misc]
                q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
            ):
                return self._traced_replacement(
                    q, k, v, scale, kv_cache_dummy_dep, layer_name
                )

            return _replacement_with_ln

        ln = _encode_layer_name(self._layer_name)

        def _replacement(q, k, v, output_attn, scale, kv_cache_dummy_dep):
            return self._traced_replacement(q, k, v, scale, kv_cache_dummy_dep, ln)

        return _replacement

    def get_inputs(self):
        """Example inputs for tracing `pattern`/`replacement`, in argument order."""
        dtype = self._dtype
        num_heads = self._num_heads
        head_size = self._head_size
        inputs: list = [
            self.empty(5, num_heads, head_size, dtype=dtype),  # q
            self.empty(5, num_heads, head_size, dtype=dtype),  # k
            self.empty(5, num_heads, head_size, dtype=dtype),  # v
            self.empty(5, num_heads, head_size, dtype=dtype),  # attn_output
            self.empty_fp32(1, 1),  # scale
            self.empty(0, dtype=dtype),  # kv_cache_dummy_dep
        ]
        if _USE_LAYERNAME:
            # Trailing `layer_name` argument of the *_with_ln variants.
            inputs.append(_encode_layer_name(self._layer_name))
        return inputs
170

171

172
173
174
class AttnNvfp4QuantPattern(
    VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph. The quant op's
    `input_scale` is passed into the Attention op as `output_scale`, and
    its block-scale buffer (viewed as FP8) as `output_block_scale`. The
    attention op then writes the FP4 output and the block scales itself.

    When `_USE_LAYERNAME` is enabled, `layer_name` is an explicit pattern
    input that the pattern matcher treats as a wildcard. A single instance
    then matches any attention layer with this instance's
    `num_heads`/`head_size`. Otherwise the encoded layer name is a constant
    in the pattern, and the instance only matches its own layer.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype):
        # Only the layer's name and head geometry are captured; the graph
        # patterns below are built from these.
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._head_size = layer.head_size
        self._dtype = dtype
        self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic]

    def _traced_pattern(
        self,
        q,
        k,
        v,
        output_attn,
        output_quant,
        output_scale,
        input_scale,
        kv_cache_dummy_dep,
        layer_name,
    ):
        """Unfused graph: attention, reshape, then the NVFP4 quant op.

        Shared by both `pattern` variants so the two cannot drift apart.
        Returns the quantized output and the block scales viewed as FP8.
        """
        at1 = auto_functionalized(
            ATTN_OP,
            query=q,
            key=k,
            value=v,
            output=output_attn,
            layer_name=layer_name,
            output_scale=None,
            output_block_scale=None,
            kv_cache_dummy_dep=kv_cache_dummy_dep,
        )
        attn_out_view = RESHAPE_OP(
            at1[1], [q.shape[0], self._num_heads * self._head_size]
        )
        at2 = auto_functionalized(
            self._QUANT_OP,
            input=attn_out_view,
            input_scale=input_scale,
            is_sf_swizzled_layout=True,
            output=output_quant,
            output_scale=output_scale,
        )
        return at2[1], torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)

    def _traced_replacement(
        self, q, k, v, output_scale, input_scale, kv_cache_dummy_dep, layer_name
    ):
        """Fused graph: attention writes FP4 output and block scales directly.

        Shared by both `replacement` variants so the two cannot drift apart.
        """
        # FP4 output buffer: head_size // 2 uint8 elements per head.
        output_attn = torch.empty(
            [q.shape[0], self._num_heads, self._head_size // 2],
            dtype=FP4_DTYPE,
            device=q.device,
        )
        # Block-scale buffer viewed as FP8, matching the pattern's output.
        osv = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
        at2 = auto_functionalized(
            ATTN_OP,
            query=q,
            key=k,
            value=v,
            output=output_attn,
            layer_name=layer_name,
            output_scale=input_scale,
            output_block_scale=osv,
            kv_cache_dummy_dep=kv_cache_dummy_dep,
        )
        return RESHAPE_OP(
            at2[1], [-1, self._num_heads * self._head_size // 2]
        ), at2[2]

    @property
    def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        if _USE_LAYERNAME:
            # layer_name is a pattern input -> wildcard over layer names.

            def _pattern_with_ln(  # type: ignore[misc]
                q,
                k,
                v,
                output_attn,
                output_quant,
                output_scale,
                input_scale,
                kv_cache_dummy_dep,
                layer_name,
            ):
                return self._traced_pattern(
                    q,
                    k,
                    v,
                    output_attn,
                    output_quant,
                    output_scale,
                    input_scale,
                    kv_cache_dummy_dep,
                    layer_name,
                )

            return _pattern_with_ln

        # Original behavior: the encoded layer name is a closure constant.
        ln = _encode_layer_name(self._layer_name)

        def _pattern(
            q,
            k,
            v,
            output_attn,
            output_quant,
            output_scale,
            input_scale,
            kv_cache_dummy_dep,
        ):
            return self._traced_pattern(
                q,
                k,
                v,
                output_attn,
                output_quant,
                output_scale,
                input_scale,
                kv_cache_dummy_dep,
                ln,
            )

        return _pattern

    @property
    def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        # The replacement keeps the pattern's signature; `output_attn` and
        # `_output_quant` are accepted but unused (new buffers are made).
        if _USE_LAYERNAME:

            def _replacement_with_ln(  # type: ignore[misc]
                q,
                k,
                v,
                output_attn,
                _output_quant,
                output_scale,
                input_scale,
                kv_cache_dummy_dep,
                layer_name,
            ):
                return self._traced_replacement(
                    q, k, v, output_scale, input_scale, kv_cache_dummy_dep, layer_name
                )

            return _replacement_with_ln

        ln = _encode_layer_name(self._layer_name)

        def _replacement(
            q,
            k,
            v,
            output_attn,
            _output_quant,
            output_scale,
            input_scale,
            kv_cache_dummy_dep,
        ):
            return self._traced_replacement(
                q, k, v, output_scale, input_scale, kv_cache_dummy_dep, ln
            )

        return _replacement

    def get_inputs(self):
        """Example inputs for tracing `pattern`/`replacement`, in argument order."""
        dtype = self._dtype
        num_heads = self._num_heads
        head_size = self._head_size
        # NOTE(review): q/k/v/output_attn are always bf16 here, while the FP8
        # pattern uses the model dtype; confirm this is intentional.
        inputs: list = [
            self.empty_bf16(5, num_heads, head_size),  # q
            self.empty_bf16(5, num_heads, head_size),  # k
            self.empty_bf16(5, num_heads, head_size),  # v
            self.empty_bf16(5, num_heads, head_size),  # output_attn
            # output_quant: FP4 values held in uint8, hidden // 2 per row.
            self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE),
            # output_scale: block scales held as int32; viewed as FP8 by the
            # pattern and the replacement.
            self.empty_i32(128, round_up(num_heads * head_size // 16, 4)),
            self.empty_fp32(1, 1),  # input_scale
            self.empty(0, dtype=dtype),  # kv_cache_dummy_dep
        ]
        if _USE_LAYERNAME:
            # Trailing `layer_name` argument of the *_with_ln variants.
            inputs.append(_encode_layer_name(self._layer_name))
        return inputs
360

361

362
class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    Two schemes are handled:
    - Static per-tensor FP8.
    - NVFP4, on CUDA builds that provide `scaled_fp4_quant`.

    Support is checked per attention layer at registration time, via
    `impl.fused_output_quant_supported()`, rather than during pattern
    matching.

    Without `_USE_LAYERNAME`, the layer name is a string constant inside
    each pattern (strings cannot be wildcarded), so one pattern is
    registered per supported layer.

    With `_USE_LAYERNAME`, the layer name is a wildcard pattern input, so a
    single pattern matches every layer with the same head geometry. One
    pattern is therefore registered per distinct `(num_heads, head_size)`,
    and only when every layer of that geometry supports the fusion.
    Otherwise the wildcard would also rewrite layers whose attention
    backend cannot produce quantized output.

    The bigger hurdle for wider support are attention kernels, which need to
    support fusing output quant.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config, "attn_quant_fusion")

        dtype = config.model_config.dtype
        layers = list(get_layers_from_vllm_config(config, Attention).values())

        if not layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self._register_patterns(
            layers, dtype, _FP8_QUANT_KEY, AttnFp8StaticQuantPattern
        )

        if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
            self._register_patterns(
                layers, dtype, kNvfp4Dynamic, AttnNvfp4QuantPattern
            )

        self.dump_patterns(config, self.pm_pass)

    def _register_patterns(
        self,
        layers: list[Attention],
        dtype: torch.dtype,
        quant_key: QuantKey,
        pattern_cls: Callable[[Attention, torch.dtype], VllmPatternReplacement],
    ) -> None:
        """Register `pattern_cls` for the layers that support `quant_key`."""
        if not _USE_LAYERNAME:
            # The layer name is a constant in each pattern: one per layer.
            for layer in layers:
                if layer.impl.fused_output_quant_supported(quant_key):
                    self.register(pattern_cls(layer, dtype))
            return

        # The layer name is a wildcard, so a registered pattern matches every
        # layer with the same head geometry. Group by geometry so that no
        # geometry is silently skipped, and so that unsupported layers are
        # never rewritten.
        # NOTE(review): groups with equal num_heads * head_size trace to
        # structurally identical patterns; confirm the matcher tolerates that.
        by_shape: dict[tuple[int, int], list[Attention]] = {}
        for layer in layers:
            by_shape.setdefault((layer.num_heads, layer.head_size), []).append(
                layer
            )

        for (num_heads, head_size), group in by_shape.items():
            unsupported = [
                layer.layer_name
                for layer in group
                if not layer.impl.fused_output_quant_supported(quant_key)
            ]
            if not unsupported:
                # Any member can serve as the template: the pattern only uses
                # its head geometry (and its name for the example inputs).
                self.register(pattern_cls(group[0], dtype))
            elif len(unsupported) < len(group):
                logger.warning(
                    "Skipping attention + %s output quant fusion for layers "
                    "with num_heads=%d, head_size=%d: the layer-name wildcard "
                    "would also fuse layers that do not support it (%s).",
                    quant_key,
                    num_heads,
                    head_size,
                    ", ".join(unsupported),
                )