attn_quant_fusion.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from typing import Any, ParamSpec
7

8
9
import torch
import torch._inductor.pattern_matcher as pm
10
from torch import fx
11
12
13
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

14
from vllm.config import VllmConfig, get_layers_from_vllm_config
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.attention import Attention
17
from vllm.model_executor.layers.quantization.utils.quant_utils import (
18
    QuantKey,
19
    kNvfp4Dynamic,
20
21
    kStaticTensorScale,
)
22
from vllm.platforms import current_platform
23
from vllm.utils.math_utils import round_up
24

25
26
27
from ..fx_utils import is_func
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
28
from .matcher_utils import MatcherQuantFP8
29
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
30
31

logger = init_logger(__name__)

# ParamSpec used to carry a trace function's call signature through
# AttentionQuantPattern.wrap_trace_fn.
P = ParamSpec("P")

# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): presumably fp4 values are stored packed in uint8 tensors;
# confirm against the fp4 quant op.
FP4_DTYPE = torch.uint8

# Ops referenced by the fusion patterns below.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


40
class AttentionQuantPattern(ABC):
41
    """
42
43
    The base class for Attn+Quant fusions.
    Should not be used directly.
44
    """
45
46
47

    def __init__(
        self,
48
        layer: Attention,
49
        quant_key: QuantKey,
50
        dtype: torch.dtype,
51
    ) -> None:
52
53
54
55
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
56
57
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
58
        self.dtype = dtype
59

60
        assert self.quant_key in QUANT_OPS, (
61
            f"unsupported quantization scheme {self.quant_key}"
62
        )
63
64
        self.QUANT_OP = QUANT_OPS[self.quant_key]

65
    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
66
        kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
67
68
        return torch.empty(*args, **kwargs)

69
    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
70
        kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
71
72
        return torch.empty(*args, **kwargs)

73
    @staticmethod
74
75
76
77
78
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
79
80
81
82
83
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)

            return gm
84
85
86
87

        return wrapped

    @staticmethod
88
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
89
        from torch._inductor.fx_passes.post_grad import view_to_reshape
90

91
        view_to_reshape(gm)
92
93

    @staticmethod
94
    def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
95
96
97
98
99
100
101
102
103
104
105
        for node in gm.graph.nodes:
            if not is_func(node, torch.ops.aten.permute.default):
                continue

            dims = node.args[1]
            if any(dim != i for i, dim in enumerate(dims)):
                continue

            # this is now an identity op, remove
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
106

107
    def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
108
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
109
110
            self._register(pm_pass)

111
    @abstractmethod
112
    def _register(self, pm_pass: PatternMatcherPass) -> None:
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            layer: attention layer the pattern is specialized for.
            dtype: dtype of the example tensors used to trace the pattern.
            symmetric: forwarded to the `QuantKey` describing the quant op.
        """
        quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, quant_key, dtype)
        self.quant_matcher = MatcherQuantFP8(quant_key)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention+fp8-quant pattern on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: attention without any output scale ...
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # ... whose result (index 1 of the functionalized call) is
            # flattened to [tokens, num_heads * head_size] ...
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )

            # ... and then quantized with the static scale.
            return self.quant_matcher(attn_out_view, scale)[0]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # attn output in quant_dtype
            output_attn = torch.empty(
                [q.shape[0], self.num_heads, self.head_size],
                dtype=self.quant_dtype,
                device=q.device,
            )
            # Fused form: the quant scale is handed to attention directly.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Example inputs for tracing; the leading 5 is presumably an
        # arbitrary example token count.
        inputs = [
            self.empty(5, self.num_heads, self.head_size),  # q
            self.empty(5, self.num_heads, self.head_size),  # k
            self.empty(5, self.num_heads, self.head_size),  # v
            self.empty(5, self.num_heads, self.head_size),  # attn_output
            empty_fp32(1, 1),  # scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            # Post-process the traced graphs after pm.fwd_only.
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
211

212

213
214
215
216
217
218
219
220
221
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph; its `input_scale` is
    passed into the Attention op as the `output_scale` argument and its
    scale buffer (viewed as fp8) as the `output_block_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """
        Args:
            layer: attention layer the pattern is specialized for.
            dtype: default dtype for tensors created via `self.empty()`.
        """
        super().__init__(layer, kNvfp4Dynamic, dtype)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention+nvfp4-quant pattern on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: attention without output scales ...
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )
            # ... followed by the standalone quant op for this quant key.
            at2 = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=attn_out_view,
                output_scale=output_scale,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
            )
            # Return the quantized output and the scale buffer viewed as fp8.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # attention output in quant_dtype
            # NOTE(review): head_size // 2 presumably reflects two fp4 values
            # packed per element; confirm against the kernel.
            output_attn = torch.empty(
                [q.shape[0], self.num_heads, self.head_size // 2],
                dtype=self.quant_dtype,
                device=q.device,
            )
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            # Fused form: both scales are handed to attention directly.
            at2 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        # Example inputs for tracing.
        # NOTE(review): q/k/v/output_attn are created with empty_bf16 rather
        # than self.empty (self.dtype), while kv_cache_dummy_dep uses
        # self.empty; confirm this is intended.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            # Post-process the traced graphs after pm.fwd_only.
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
318
319


320
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Patterns are registered for static fp8 quant and, on CUDA when
    `torch.ops._C.scaled_fp4_quant` is available, for NVFP4 quant. Patterns
    could easily be added for other quant schemes and dtypes. The bigger
    hurdle for wider support are attention kernels, which need to support
    fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Build the pattern matcher pass and register the supported fusion
        patterns for every attention layer found in `config`.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)

        # Platform/op availability does not depend on the layer: check once.
        nvfp4_available = current_platform.is_cuda() and hasattr(
            torch.ops._C, "scaled_fp4_quant"
        )

        for layer in attn_layers.values():
            pattern_fp8 = AttentionFp8StaticQuantPattern(
                layer, config.model_config.dtype
            )
            pattern_fp8.register_if_supported(self.patterns)

            if nvfp4_available:
                pattern_nvfp4 = AttentionNvfp4QuantPattern(
                    layer, config.model_config.dtype
                )
                pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self) -> str:
        """Return a hash over this pass and the pattern classes it uses."""
        return VllmInductorPass.hash_source(
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )