attn_quant_fusion.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from typing import Any, ParamSpec
7

8
9
import torch
import torch._inductor.pattern_matcher as pm
10
from torch import fx
11
12
13
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

14
from vllm.config import VllmConfig, get_layers_from_vllm_config
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.attention import Attention
17
from vllm.model_executor.layers.quantization.utils.quant_utils import (
18
    QuantKey,
19
    kNvfp4Dynamic,
20
21
    kStaticTensorScale,
)
22
from vllm.platforms import current_platform
23
from vllm.utils.math_utils import round_up
24

25
26
27
from ..fx_utils import is_func
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
28
from .matcher_utils import MatcherQuantFP8
29
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
30
31

logger = init_logger(__name__)

# Captures the parameters of the trace function wrapped by
# AttentionQuantPattern.wrap_trace_fn so the wrapper keeps the same call
# signature.
P = ParamSpec("P")

# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 tensors are stored as uint8
# (presumably two packed 4-bit values per byte - TODO confirm).
FP4_DTYPE = torch.uint8

# Ops referenced by the fusion patterns below.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


40
class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    A subclass implements `_register()` to add one pattern/replacement pair
    for a single attention layer to a `PatternMatcherPass`. Callers should go
    through `register_if_supported()`, which only registers the pattern when
    the layer's attention implementation reports support for the quant key.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ) -> None:
        """
        Cache the layer properties used when building the patterns.

        Args:
            layer: attention layer whose output quantization is to be fused.
            quant_key: quantization scheme; must be a key of `QUANT_OPS`.
            dtype: dtype used by `empty()` for the example tensors.

        Raises:
            AssertionError: if `quant_key` has no entry in `QUANT_OPS`.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        `torch.empty` defaulting to `self.dtype` on "cuda".

        Explicitly passed `dtype`/`device` keyword arguments take precedence
        over the defaults.
        """
        kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        `torch.empty` defaulting to `self.quant_dtype` on "cuda".

        Explicitly passed `dtype`/`device` keyword arguments take precedence
        over the defaults.
        """
        kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """
        Return a trace function that post-processes the traced graph.

        The returned callable forwards all arguments to `trace_fn`, then runs
        every function in `process_fx_fns` on the resulting graph module, in
        the order given, and returns that graph module.
        """

        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)

            return gm

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        """Run Inductor's `view_to_reshape` post-grad pass on `gm` in place."""
        # Imported lazily, only when a pattern is actually traced.
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
        """
        Remove `aten.permute` nodes whose dims are the identity permutation.

        Every user of such a node is rewired to the permute's input and the
        node is erased from the graph. `gm` is modified in place.
        """
        # Iterate over a snapshot: nodes are erased inside the loop, and the
        # traversal should not depend on the live node list while it mutates.
        for node in list(gm.graph.nodes):
            if not is_func(node, torch.ops.aten.permute.default):
                continue

            dims = node.args[1]
            if any(dim != i for i, dim in enumerate(dims)):
                continue

            # this is now an identity op, remove
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

    def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register this pattern on `pm_pass` only if the layer's attention
        implementation reports fused output quant support for `quant_key`.
        """
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair for this layer on `pm_pass`."""
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            layer: attention layer whose output quantization is to be fused.
            dtype: dtype of the example tensors built by `empty()`.
            symmetric: forwarded to the `QuantKey` for the static
                per-tensor FP8 scheme.
        """
        quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, quant_key, dtype)
        # Emits the FP8 quant op sequence that the pattern has to match.
        self.quant_matcher = MatcherQuantFP8(quant_key)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register the attention + static FP8 quant fusion for this layer.

        The pattern is attention with no output scale, a reshape to
        [num_tokens, num_heads * head_size], then the FP8 quant. The
        replacement is a single attention op that writes into a buffer of
        `quant_dtype` and receives the quant scale as `output_scale`.
        """

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # at1[1] is presumably the functionalized `output` buffer
            # - TODO confirm against auto_functionalized's return layout.
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )

            return self.quant_matcher(attn_out_view, scale)[0]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # attn output in quant_dtype
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Example inputs used only to trace the two functions above.
        inputs = [
            self.empty(5, self.num_heads, self.head_size),  # q
            self.empty(5, self.num_heads, self.head_size),  # k
            self.empty(5, self.num_heads, self.head_size),  # v
            self.empty(5, self.num_heads, self.head_size),  # attn_output
            empty_fp32(1, 1),  # scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            # Post-process the traced graphs: views are converted to reshapes
            # and identity permutes are dropped after tracing.
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
212

213

214
215
216
217
218
219
220
221
222
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph. Its `input_scale` will be
    passed into Attention op as the `output_scale` argument, and its
    `output_scale` buffer (viewed as FP8) as the `output_block_scale`
    argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """
        Args:
            layer: attention layer whose output quantization is to be fused.
            dtype: dtype of the example tensors built by `empty()`.
        """
        super().__init__(layer, kNvfp4Dynamic, dtype)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register the attention + NVFP4 quant fusion for this layer.

        The pattern is attention with no output scales, a reshape to
        [num_tokens, num_heads * head_size], then the NVFP4 quant op, whose
        scale output is viewed as FP8. The replacement is a single attention
        op that writes into a buffer of `quant_dtype` and receives both the
        scale and the block-scale buffer.
        """

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )
            at2 = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=attn_out_view,
                output_scale=output_scale,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
            )
            # at2[1] / at2[2] are presumably the functionalized `output` and
            # `output_scale` buffers - TODO confirm.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # attention output in quant_dtype
            # (last dim is head_size // 2; presumably two fp4 values are
            # packed per element - TODO confirm)
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            at2 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        # Example inputs used only to trace the two functions above.
        # NOTE(review): q/k/v/output_attn are hard-coded to bf16 here, while
        # kv_cache_dummy_dep follows self.dtype - confirm this is intentional.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            # Post-process the traced graphs: views are converted to reshapes
            # and identity permutes are dropped after tracing.
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
320
321


322
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Static fp8 quant patterns are registered for every attention layer, and
    NVFP4 quant patterns are additionally registered when running on CUDA
    with `torch.ops._C.scaled_fp4_quant` available. Patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Build the pattern matcher pass and register the fusion patterns for
        every attention layer found in `config`.

        A warning is logged if no attention layers are found, in which case
        no patterns are registered.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)
        model_dtype = config.model_config.dtype

        # The NVFP4 availability check does not depend on the layer, so it is
        # evaluated once instead of on every loop iteration.
        nvfp4_available = current_platform.is_cuda() and hasattr(
            torch.ops._C, "scaled_fp4_quant"
        )

        # Only the layer objects are needed; each pattern reads the layer
        # name from the layer itself.
        for layer in attn_layers.values():
            pattern_fp8 = AttentionFp8StaticQuantPattern(layer, model_dtype)
            pattern_fp8.register_if_supported(self.patterns)

            if nvfp4_available:
                pattern_nvfp4 = AttentionNvfp4QuantPattern(layer, model_dtype)
                pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self) -> str:
        """
        Return a hash over the source of this pass and of the pattern classes
        it registers.
        """
        return VllmInductorPass.hash_source(
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )