fusion_attn.py 11.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from abc import ABC, abstractmethod

6
7
8
9
10
11
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.attention import Attention
12
from vllm.config import VllmConfig, get_layers_from_vllm_config
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
18
    QuantKey,
    kNvfp4Quant,
    kStaticTensorScale,
)
19
from vllm.platforms import current_platform
20
from vllm.utils import round_up
21

22
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
23
from .inductor_pass import enable_fake_mode
24
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
25
26
27

logger = init_logger(__name__)

# Platform FP8 dtype: used as the static-FP8 QuantKey dtype and to reinterpret
# the NVFP4 block-scale buffer in the patterns below.
FP8_DTYPE = current_platform.fp8_dtype()
# uint8 storage dtype for FP4 data.
FP4_DTYPE = torch.uint8

# Ops that the fusion patterns match and emit.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


35
class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    A subclass describes, for one attention layer, an
    "attention followed by a quant op" pattern plus its fused replacement,
    and implements `_register` to add that rewrite to a `PatternMatcherPass`.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ):
        """
        Args:
            layer: Attention layer whose ops the pattern matches; its
                `layer_name`, `num_heads` and `head_size` are cached.
            quant_key: Quantization scheme; must be a key of `QUANT_OPS`.
            dtype: Model (unquantized) dtype, the default for `empty()`.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size

        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        # Standalone quant op whose output the subclass pattern fuses away.
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args, **kwargs):
        """`torch.empty` on CUDA, defaulting to the model dtype.

        Explicit `dtype`/`device` entries in `kwargs` override the defaults.
        """
        kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def empty_quant(self, *args, **kwargs):
        """`torch.empty` on CUDA, defaulting to the quantized dtype.

        Explicit `dtype`/`device` entries in `kwargs` override the defaults.
        """
        kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(process_fx, trace_fn):
        """Return a tracer that runs `trace_fn` and post-processes its result.

        Args:
            process_fx: Callable applied to the object returned by `trace_fn`.
            trace_fn: Tracing function (e.g. `pm.fwd_only`).

        Returns:
            A callable with `trace_fn`'s signature returning
            `process_fx(trace_fn(...))`.
        """

        def wrapped(*args, **kwargs):
            return process_fx(trace_fn(*args, **kwargs))

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        """Apply inductor's `view_to_reshape` to `gm` and return `gm`.

        The return value of `view_to_reshape` is discarded, so it is assumed to
        rewrite `gm` in place. Presumably this lets traced patterns (which use
        reshape) line up with post-grad graphs -- confirm.
        """
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)
        return gm

    def register_if_supported(self, pm_pass: PatternMatcherPass):
        """Register this pattern only if the layer's attention impl can fuse it.

        Support is queried via
        `layer.impl.fused_output_quant_supported(quant_key)`.
        """
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass):
        """Add this fusion's pattern/replacement pair to `pm_pass`."""
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ):
        """
        Args:
            layer: Attention layer to build the pattern for.
            dtype: Model (unquantized) dtype.
            symmetric: Whether the static FP8 quantization is symmetric.
        """
        quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, quant_key, dtype)

    def _register(self, pm_pass: PatternMatcherPass):
        """Register the attention -> static FP8 quant rewrite on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unquantized attention writing into `output_attn`.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
            )
            # Flatten heads: [num_tokens, num_heads * head_size].
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )
            # Separate static FP8 quant of the attention output.
            at2 = auto_functionalized(
                self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
            )
            return at2[1]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Same arguments as `pattern`. `output_attn` is superseded by a
            # fresh buffer in quant_dtype and `output_quant` is unused, since
            # attention writes the quantized result directly.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # Attention with fused output quant; the static scale goes in as
            # `output_scale`.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
            )
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Example inputs for tracing `pattern` and `replacement` (the token
        # count of 5 is presumably arbitrary).
        inputs = [
            self.empty(5, self.num_heads, self.head_size),  # q
            self.empty(5, self.num_heads, self.head_size),  # k
            self.empty(5, self.num_heads, self.head_size),  # v
            self.empty(5, self.num_heads, self.head_size),  # attn_output
            self.empty_quant(5, self.num_heads * self.head_size),  # quant_output
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            AttentionQuantPattern.wrap_trace_fn(
                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
            ),
            pm_pass,
        )
186

187

188
189
190
191
192
193
194
195
196
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph: its global `input_scale`
    is passed into the Attention op as the `output_scale` argument, and its
    block-scale buffer as the `output_block_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype):
        """
        Args:
            layer: Attention layer to build the pattern for.
            dtype: Model (unquantized) dtype.
        """
        super().__init__(layer, kNvfp4Quant, dtype)

    def _register(self, pm_pass: PatternMatcherPass):
        """Register the attention -> NVFP4 quant rewrite on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
        ):
            # Unquantized attention writing into `output_attn`.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
            )
            # Flatten heads: [num_tokens, num_heads * head_size].
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )
            # Separate NVFP4 quant of the attention output.
            at2 = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=attn_out_view,
                output_scale=output_scale,
                input_scale=input_scale,
            )
            # at2[2] is presumably the mutated `output_scale` buffer; view it
            # as FP8 so it lines up with what the replacement returns.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
        ):
            # Same arguments as `pattern`. The attention output buffer is
            # reallocated in quant_dtype with last dim head_size // 2, and
            # `output_quant` is unused.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # Block-scale buffer reinterpreted as FP8 for the attention kernel.
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            # Attention with fused NVFP4 output quant: the global scale goes in
            # as `output_scale`, the block scales as `output_block_scale`.
            at2 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
            )
            output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        # Example inputs for tracing `pattern` and `replacement`.
        # NOTE(review): unlike the FP8 pattern, q/k/v/output_attn here are
        # always bf16 rather than `self.dtype` -- confirm this is intended.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            AttentionQuantPattern.wrap_trace_fn(
                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
            ),
            pm_pass,
        )
286
287


288
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, static FP8 quant is supported, as is NVFP4 quant on CUDA when
    the `scaled_fp4_quant` custom op is available. Patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build and register fusion patterns for every attention layer.

        Args:
            config: vLLM config; its attention layers and
                `model_config.dtype` determine which patterns are built.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)

        # NVFP4 patterns are only built when running on CUDA with the
        # scaled_fp4_quant custom op present. This is layer-independent, so it
        # is checked once rather than per layer.
        nvfp4_available = current_platform.is_cuda() and hasattr(
            torch.ops._C, "scaled_fp4_quant"
        )

        for layer in attn_layers.values():
            # Each pattern registers itself only if this layer's attention
            # impl reports support for that fused output quantization.
            pattern_fp8 = AttentionFp8StaticQuantPattern(
                layer, config.model_config.dtype
            )
            pattern_fp8.register_if_supported(self.patterns)

            if nvfp4_available:
                pattern_nvfp4 = AttentionNvfp4QuantPattern(
                    layer, config.model_config.dtype
                )
                pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered patterns to `graph` in place.

        The number of matches is stored in `self.matched_count`.
        """
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self):
        """Identify this pass by hashing the source of the pass and its patterns."""
        return VllmInductorPass.hash_source(
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )