fusion_attn.py 11.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from abc import ABC, abstractmethod

6
7
8
9
10
11
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.attention import Attention
12
from vllm.config import VllmConfig, get_layers_from_vllm_config
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kNvfp4Quant, kStaticTensorScale)
16
from vllm.platforms import current_platform
17
from vllm.utils import round_up
18

19
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
20
from .inductor_pass import enable_fake_mode
21
22
23
24
from .vllm_inductor_pass import VllmInductorPass

# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# Dtypes of the quantized attention outputs produced by the fused patterns.
# The FP8 dtype is whatever the current platform reports.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 data is stored in uint8 tensors.
# NOTE(review): presumably two 4-bit values are packed per byte -- confirm.
FP4_DTYPE = torch.uint8

# Op overloads that the fusion patterns below match on and emit.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


32
class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    Stores the attention layer and the quantization key, resolves the
    quant op for that key from `QUANT_OPS`, and offers small helpers used
    by subclasses when registering their patterns. Subclasses implement
    `_register()`; callers use `register_if_supported()`.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
    ):
        """
        :param layer: the attention layer this pattern is specialized for;
            its `layer_name`, `num_heads` and `head_size` are cached here.
        :param quant_key: the quantization scheme; must be a key of
            `QUANT_OPS`, otherwise an AssertionError is raised.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        # The quant op that subclasses match in their search pattern.
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        """
        Create an uninitialized tensor via `torch.empty`.

        `dtype` defaults to this pattern's quant dtype and `device` to
        "cuda"; explicitly passed keyword arguments override both.
        """
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(process_fx, trace_fn):
        """
        Compose `process_fx` after `trace_fn`.

        :return: a callable that forwards all arguments to `trace_fn` and
            returns `process_fx` applied to its result.
        """

        def wrapped(*args, **kwargs):
            return process_fx(trace_fn(*args, **kwargs))

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        """
        Run Inductor's `view_to_reshape` post-grad pass on `gm` and return
        the same graph module.
        """
        # Imported lazily, only when a traced graph is actually processed.
        from torch._inductor.fx_passes.post_grad import view_to_reshape
        view_to_reshape(gm)
        return gm

    def register_if_supported(self, pm_pass: PatternMatcherPass):
        """
        Register this pattern on `pm_pass`, but only when the layer's
        attention implementation reports fused output quant support for
        this pattern's quant key. Otherwise do nothing.
        """
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on `pm_pass`."""
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        symmetric: bool = True,
    ):
        """
        :param layer: the attention layer to build the pattern for.
        :param symmetric: forwarded to the `QuantKey` as `symmetric`.
        """
        quant_key = QuantKey(dtype=FP8_DTYPE,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(layer, quant_key)

    def _register(self, pm_pass: PatternMatcherPass):
        """
        Register the attention -> reshape -> static fp8 quant pattern and
        its fused replacement on `pm_pass`.
        """

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused attention: no output scale is passed.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=None,
                                      output_block_scale=None)
            # at1[1] is presumably the functionalized `output` tensor
            # (TODO confirm); it is flattened to 2D before quantization.
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size])
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=output_quant,
                                      input=attn_out_view,
                                      scale=scale)
            return at2[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            # attn output in quant_dtype
            # A fresh zero-filled buffer replaces the incoming output_attn.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # The quant scale is handed to attention as `output_scale`,
            # so the separate quant op is no longer needed.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=scale,
                                      output_block_scale=None)
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Example inputs used only for tracing the pattern and replacement.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # attn_output
            self.empty_quant(5,
                             self.num_heads * self.head_size),  # quant_output
            empty_fp32(1, 1)  # scale
        ]

        # Trace with fwd_only, then post-process the traced graph with
        # fx_view_to_reshape before it is registered.
        pm.register_replacement(
            pattern, replacement, inputs,
            AttentionQuantPattern.wrap_trace_fn(
                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
            pm_pass)
156

157

158
159
160
161
162
163
164
165
166
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(self, layer: Attention):
        """
        :param layer: the attention layer to build the pattern for; the
            quant key is always `kNvfp4Quant`.
        """
        super().__init__(layer, kNvfp4Quant)

    def _register(self, pm_pass: PatternMatcherPass):
        """
        Register the attention -> reshape -> nvfp4 quant pattern and its
        fused replacement on `pm_pass`. Both return a pair: the quantized
        output and the block scales viewed as `FP8_DTYPE`.
        """

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    output_scale: torch.Tensor, input_scale: torch.Tensor):
            # Unfused attention: no output scales are passed.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=None,
                                      output_block_scale=None)
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size])
            at2 = auto_functionalized(self.QUANT_OP,
                                      output=output_quant,
                                      input=attn_out_view,
                                      output_scale=output_scale,
                                      input_scale=input_scale)
            # at2[2] is presumably the functionalized `output_scale`
            # (TODO confirm); it is reinterpreted as FP8_DTYPE.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        output_scale: torch.Tensor, input_scale: torch.Tensor):
            # attention output in quant_dtype
            # The last dim is head_size // 2, matching the halved width of
            # the example output_quant below.
            # NOTE(review): presumably two fp4 values per byte -- confirm.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(
                output_scale, FP8_DTYPE)
            # The quant op's input_scale becomes attention's output_scale;
            # the FP8 view of output_scale becomes output_block_scale.
            at2 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=input_scale,
                                      output_block_scale=output_scale_view)
            output = RESHAPE_OP(at2[1],
                                [-1, self.num_heads * self.head_size // 2])
            # at2[2] is presumably the functionalized block-scale tensor
            # (TODO confirm).
            return output, at2[2]

        # Example inputs used only for tracing the pattern and replacement.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size //
                             2),  # output_quant
            empty_i32(128, round_up(self.num_heads * self.head_size // 16,
                                    4)),  # output_scale
            empty_fp32(1, 1),  # input_scale
        ]

        # Trace with fwd_only, then post-process the traced graph with
        # fx_view_to_reshape before it is registered.
        pm.register_replacement(
            pattern, replacement, inputs,
            AttentionQuantPattern.wrap_trace_fn(
                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
            pm_pass)
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249


class AttnFusionPass(VllmInductorPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, static fp8 and nvfp4 quant patterns are registered, and
    patterns could easily be added for other quant schemes and dtypes. The
    bigger hurdle for wider support are attention kernels, which need to
    support fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """
        Build the pattern matcher pass and register, for every attention
        layer found in `config`, each quant pattern the layer supports.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)
        # Patterns embed the layer name, so one pattern per layer and per
        # quant scheme is created; unsupported ones are skipped.
        for layer in attn_layers.values():
            pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
            pattern_fp8.register_if_supported(self.patterns)

            pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
            pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered.")

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Apply all registered patterns to `graph` in place, then remove
        dead code. The graph is dumped before and after the pass.
        """
        self.begin()
        self.dump_graph(graph, "before_attn_fusion")

        count = self.patterns.apply(graph)

        # TODO: Move this to pass_manager.py after the fx graph broken issue
        # has been resolved.
        # see https://github.com/vllm-project/vllm/issues/23091
        graph.eliminate_dead_code()

        logger.debug("Fused quantization onto %s attention nodes", count)
        self.dump_graph(graph, "after_attn_fusion")
        self.end_and_log()

    def uuid(self):
        """
        Return a hash over the source of this pass and of the pattern
        classes it relies on.
        """
        return VllmInductorPass.hash_source(self, AttentionQuantPattern,
                                            AttentionFp8StaticQuantPattern,
                                            AttentionNvfp4QuantPattern)