# fusion_attn.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from typing import Any, ParamSpec
7

8
9
import torch
import torch._inductor.pattern_matcher as pm
10
from torch import fx
11
12
13
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

14
from vllm.attention.layer import Attention
15
from vllm.config import VllmConfig, get_layers_from_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.quantization.utils.quant_utils import (
18
    QuantKey,
19
    kNvfp4Dynamic,
20
21
    kStaticTensorScale,
)
22
from vllm.platforms import current_platform
23
from vllm.utils.math_utils import round_up
24

25
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
26
from .fx_utils import is_func
27
from .inductor_pass import enable_fake_mode
28
from .matcher_utils import MatcherQuantFP8
29
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
30
31

# Module-level logger, named after this module.
logger = init_logger(__name__)

# ParamSpec used to keep the signature of a wrapped trace function intact
# (see AttentionQuantPattern.wrap_trace_fn).
P = ParamSpec("P")

# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): FP4 data is represented as uint8 here, presumably because
# FP4 values are stored packed in bytes - confirm against the quant kernels.
FP4_DTYPE = torch.uint8

# Ops the patterns below are built from: the attention custom op that writes
# into a caller-provided output tensor, and the aten reshape overload.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


40
class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    Subclasses implement `_register()` to add a (pattern, replacement) pair
    for one attention layer to a `PatternMatcherPass`. Callers should go
    through `register_if_supported()`, which only registers when the layer's
    attention implementation reports support for the quantization scheme.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ) -> None:
        """
        Cache the layer properties and quantization op used by the pattern.

        Args:
            layer: The attention layer this pattern is specialized for.
            quant_key: The quantization scheme; must be a key of `QUANT_OPS`.
            dtype: Default dtype of the tensors created by `empty()`.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Create an uninitialized tensor via `torch.empty`.

        `dtype` defaults to `self.dtype` and `device` to "cuda"; both can be
        overridden through `kwargs`.
        """
        kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Create an uninitialized tensor via `torch.empty`.

        `dtype` defaults to `self.quant_dtype` and `device` to "cuda"; both
        can be overridden through `kwargs`.
        """
        kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """
        Wrap a trace function so its result is post-processed in place.

        Args:
            trace_fn: Function producing an `fx.GraphModule`.
            process_fx_fns: Callables applied, in the order given, to the
                traced graph module; their return values are ignored.

        Returns:
            A callable with the same parameters as `trace_fn` that returns
            the graph module after all post-processing functions have run.
        """

        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)

            return gm

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        """Apply Inductor's post-grad `view_to_reshape` pass to `gm`."""
        # Imported lazily, only when a pattern is actually traced.
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
        """
        Remove `aten.permute` nodes whose dims are the identity permutation.

        A permute with dims `[0, 1, ..., n-1]` is replaced by its input in
        all users and then erased from the graph.
        """
        for node in gm.graph.nodes:
            if not is_func(node, torch.ops.aten.permute.default):
                continue

            dims = node.args[1]
            # Any dim that is not at its own index makes this a real permute.
            if any(dim != i for i, dim in enumerate(dims)):
                continue

            # this is now an identity op, remove
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

    def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register this pattern on `pm_pass` if the layer supports it.

        Registration only happens when the layer's attention implementation
        returns a truthy value from `fused_output_quant_supported()` for
        this pattern's quant key.
        """
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its replacement on `pm_pass`."""
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            layer: The attention layer this pattern is specialized for.
            dtype: Dtype of the example tensors used to trace the pattern.
            symmetric: Forwarded to the `QuantKey` describing the scheme.
        """
        quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, quant_key, dtype)
        self.quant_matcher = MatcherQuantFP8(quant_key)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention + static FP8 quant fusion on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: attention without output scales ...
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
            )
            # ... whose result is flattened to [num_tokens, hidden] ...
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )

            # ... and then quantized by a separate op.
            return self.quant_matcher(attn_out_view, scale)[0]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # attn output in quant_dtype
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # Fused form: the quant scale is handed to the attention op.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
            )
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Example inputs used only for tracing the pattern and replacement.
        inputs = [
            self.empty(5, self.num_heads, self.head_size),  # q
            self.empty(5, self.num_heads, self.head_size),  # k
            self.empty(5, self.num_heads, self.head_size),  # v
            self.empty(5, self.num_heads, self.head_size),  # attn_output
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            # Normalize the traced graphs (views -> reshapes, drop identity
            # permutes) before they are registered.
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
207

208

209
210
211
212
213
214
215
216
217
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph; its input scale will be
    passed into Attention op as the `output_scale` argument and its scale
    output buffer (viewed as FP8) as the `output_block_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """
        Args:
            layer: The attention layer this pattern is specialized for.
            dtype: Stored as the default dtype for `empty()`.
        """
        super().__init__(layer, kNvfp4Dynamic, dtype)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention + NVFP4 quant fusion on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: attention without output scales ...
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
            )
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )
            # ... followed by a separate quant op that fills both the
            # quantized output and the scale buffer.
            at2 = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=attn_out_view,
                output_scale=output_scale,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
            )
            # The scale buffer is reinterpreted as FP8 for the consumer.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # attention output in quant_dtype
            # NOTE(review): last dim is head_size // 2, presumably because
            # two FP4 values are packed per element - confirm.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            # Fused form: attention receives both the quant input scale and
            # the block-scale buffer.
            at2 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
            )
            output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        # Example inputs used only for tracing the pattern and replacement.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            # Normalize the traced graphs (views -> reshapes, drop identity
            # permutes) before they are registered.
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
310
311


312
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, a static fp8 quant pattern is attempted for every attention
    layer, and an nvfp4 quant pattern is additionally attempted on CUDA
    platforms where the `scaled_fp4_quant` op is available. Patterns could
    easily be added for other quant schemes and dtypes. The bigger hurdle for
    wider support are attention kernels, which need to support fusing output
    quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Build the pattern matcher pass and register per-layer patterns.

        Args:
            config: The vLLM config; attention layers are looked up from it
                and `config.model_config.dtype` is used as the pattern dtype.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)
        if attn_layers:
            # The availability of NVFP4 quant does not depend on the layer,
            # so it is checked once rather than per layer.
            nvfp4_available = current_platform.is_cuda() and hasattr(
                torch.ops._C, "scaled_fp4_quant"
            )

            # Only the layer objects are needed; each pattern reads the
            # layer name from the layer itself.
            for layer in attn_layers.values():
                pattern_fp8 = AttentionFp8StaticQuantPattern(
                    layer, config.model_config.dtype
                )
                pattern_fp8.register_if_supported(self.patterns)

                if nvfp4_available:
                    pattern_nvfp4 = AttentionNvfp4QuantPattern(
                        layer, config.model_config.dtype
                    )
                    pattern_nvfp4.register_if_supported(self.patterns)
        else:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered patterns to `graph` and record match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self) -> str:
        """Return a hash over this pass and the pattern classes it uses."""
        return VllmInductorPass.hash_source(
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )