# fusion_attn.py (12.1 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from abc import ABC, abstractmethod

6
7
8
9
10
11
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.attention import Attention
12
from vllm.config import VllmConfig, get_layers_from_vllm_config
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kNvfp4Quant, kStaticTensorScale)
16
from vllm.platforms import current_platform
17
from vllm.utils import round_up
18

19
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
20
from .inductor_pass import enable_fake_mode
21
22
23
24
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)

# Quantized dtypes used by the fusion patterns below. The FP8 dtype is
# whatever the current platform reports.
FP8_DTYPE = current_platform.fp8_dtype()
# uint8 is used as the storage dtype for FP4 data
# (presumably two 4-bit values packed per byte -- TODO confirm).
FP4_DTYPE = torch.uint8

# Ops referenced by the patterns: the unified attention custom op that writes
# into a caller-provided output tensor, and the aten reshape op.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


32
class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    Subclasses implement `_register()` to add their pattern/replacement pair
    to a `PatternMatcherPass`. Callers should go through
    `register_if_supported()`, which only registers the pattern when the
    layer's attention implementation reports support for the quant scheme.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ):
        """
        :param layer: attention layer this pattern is specialized for; its
            `layer_name`, `num_heads` and `head_size` are cached on self.
        :param quant_key: quantization scheme; must be a key of `QUANT_OPS`.
        :param dtype: default dtype for tensors created via `empty()`.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        # The quant op that the pattern has to match for this scheme.
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args, **kwargs):
        """
        Create an uninitialized tensor via `torch.empty`.

        `dtype` defaults to `self.dtype` and `device` to "cuda"; either can
        be overridden through `kwargs`.
        """
        kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def empty_quant(self, *args, **kwargs):
        """
        Create an uninitialized tensor via `torch.empty`.

        `dtype` defaults to `self.quant_dtype` and `device` to "cuda"; either
        can be overridden through `kwargs`.
        """
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(process_fx, trace_fn):
        """
        Compose `process_fx` after `trace_fn`.

        :return: a callable that forwards all arguments to `trace_fn` and
            returns `process_fx` applied to its result.
        """

        def wrapped(*args, **kwargs):
            return process_fx(trace_fn(*args, **kwargs))

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        """
        Run Inductor's `view_to_reshape` post-grad pass on `gm` and return
        the same graph module.
        """
        # Imported lazily, only when a pattern is actually traced.
        from torch._inductor.fx_passes.post_grad import view_to_reshape
        view_to_reshape(gm)
        return gm

    def register_if_supported(self, pm_pass: PatternMatcherPass):
        """
        Register this pattern on `pm_pass`, but only if the layer's attention
        implementation supports fused output quant for `self.quant_key`.
        """
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on `pm_pass`."""
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ):
        """
        :param layer: attention layer this pattern is specialized for.
        :param dtype: dtype of the example q/k/v/output tensors used to trace
            the pattern.
        :param symmetric: forwarded to the `QuantKey` for the static
            per-tensor FP8 scheme.
        """
        quant_key = QuantKey(dtype=FP8_DTYPE,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(layer, quant_key, dtype)

    def _register(self, pm_pass: PatternMatcherPass):
        """Register the attention -> reshape -> FP8 quant pattern."""

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: attention called without any output scale ...
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=None,
                                      output_block_scale=None)
            # ... its output flattened to [q.shape[0], num_heads * head_size]
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size])
            # ... and then quantized by a separate op.
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=output_quant,
                                      input=attn_out_view,
                                      scale=scale)
            return at2[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            # attn output in quant_dtype
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # Fused form: the quant scale is handed to attention directly,
            # so the separate quant op is no longer needed.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=scale,
                                      output_block_scale=None)
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Example inputs used only to trace pattern and replacement; they
        # follow the parameter order of the two functions above.
        inputs = [
            self.empty(5, self.num_heads, self.head_size,
                       dtype=self.dtype),  # q
            self.empty(5, self.num_heads, self.head_size,
                       dtype=self.dtype),  # k
            self.empty(5, self.num_heads, self.head_size,
                       dtype=self.dtype),  # v
            self.empty(5, self.num_heads, self.head_size,
                       dtype=self.dtype),  # attn_output
            self.empty_quant(5,
                             self.num_heads * self.head_size),  # quant_output
            empty_fp32(1, 1)  # scale
        ]

        # The traced graphs are post-processed with fx_view_to_reshape before
        # being registered on the pass.
        pm.register_replacement(
            pattern, replacement, inputs,
            AttentionQuantPattern.wrap_trace_fn(
                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
            pm_pass)
167

168

169
170
171
172
173
174
175
176
177
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph. Its `input_scale` is
    passed into the Attention op as the `output_scale` argument, and its
    `output_scale` buffer (viewed as FP8) is passed as the
    `output_block_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype):
        """
        :param layer: attention layer this pattern is specialized for.
        :param dtype: stored as `self.dtype` by the base class.
        """
        super().__init__(layer, kNvfp4Quant, dtype)

    def _register(self, pm_pass: PatternMatcherPass):
        """Register the attention -> reshape -> NVFP4 quant pattern."""

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    output_scale: torch.Tensor, input_scale: torch.Tensor):
            # Unfused form: attention called without any output scales ...
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=None,
                                      output_block_scale=None)
            # ... its output flattened to [q.shape[0], num_heads * head_size]
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size])
            # ... then quantized by a separate op that also writes the scale
            # buffer.
            at2 = auto_functionalized(self.QUANT_OP,
                                      output=output_quant,
                                      input=attn_out_view,
                                      output_scale=output_scale,
                                      input_scale=input_scale)
            # The scale buffer is reinterpreted as FP8 after quantization.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        output_scale: torch.Tensor, input_scale: torch.Tensor):
            # attention output in quant_dtype
            # (last dim is head_size // 2; presumably two FP4 values are
            # packed per element -- TODO confirm)
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(
                output_scale, FP8_DTYPE)
            # Fused form: both scales are handed to attention directly, so
            # the separate quant op is no longer needed.
            at2 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=input_scale,
                                      output_block_scale=output_scale_view)
            output = RESHAPE_OP(at2[1],
                                [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        # Example inputs used only to trace pattern and replacement; they
        # follow the parameter order of the two functions above.
        # NOTE(review): q/k/v/output_attn are always created as bf16 here,
        # whereas the FP8 pattern uses self.dtype -- confirm this is intended.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size //
                             2),  # output_quant
            empty_i32(128, round_up(self.num_heads * self.head_size // 16,
                                    4)),  # output_scale
            empty_fp32(1, 1),  # input_scale
        ]

        # The traced graphs are post-processed with fx_view_to_reshape before
        # being registered on the pass.
        pm.register_replacement(
            pattern, replacement, inputs,
            AttentionQuantPattern.wrap_trace_fn(
                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
            pm_pass)
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260


class AttnFusionPass(VllmInductorPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, static fp8 quant is supported, as well as nvfp4 quant on CUDA
    when the `scaled_fp4_quant` op is available. Patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """
        Build the pattern matcher pass and register the fusion patterns for
        every attention layer found in `config`.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)

        # NVFP4 availability does not depend on the layer, so check it once.
        nvfp4_available = current_platform.is_cuda() and hasattr(
            torch.ops._C, "scaled_fp4_quant")

        # One pattern per layer: the layer name is baked into each pattern.
        for layer in attn_layers.values():
            pattern_fp8 = AttentionFp8StaticQuantPattern(
                layer, config.model_config.dtype)
            pattern_fp8.register_if_supported(self.patterns)

            if nvfp4_available:
                pattern_nvfp4 = AttentionNvfp4QuantPattern(
                    layer, config.model_config.dtype)
                pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered.")

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered fusion patterns to `graph` in place."""
        self.begin()
        self.dump_graph(graph, "before_attn_fusion")

        count = self.patterns.apply(graph)

        # TODO: Move this to pass_manager.py after the fx graph broken issue
        # has been resolved.
        # see https://github.com/vllm-project/vllm/issues/23091
        graph.eliminate_dead_code()

        logger.debug("Fused quantization onto %s attention nodes", count)
        self.dump_graph(graph, "after_attn_fusion")
        self.end_and_log()

    def uuid(self):
        """
        Return an identifier derived from the source of this pass and of the
        pattern classes it uses.
        """
        return VllmInductorPass.hash_source(self, AttentionQuantPattern,
                                            AttentionFp8StaticQuantPattern,
                                            AttentionNvfp4QuantPattern)