# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod

6
7
8
9
10
11
12
13
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
                                           unset_fake_temporarily)

from vllm.attention import Attention
14
from vllm.config import VllmConfig, get_layers_from_vllm_config
15
from vllm.logger import init_logger
16
17
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kNvfp4Quant, kStaticTensorScale)
18
from vllm.platforms import current_platform
19
from vllm.utils import round_up
20

21
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
22
23
24
25
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)

# fp8 dtype reported by the current platform; used as the fp8 quant dtype and
# as the reinterpretation dtype for nvfp4 block scales below.
FP8_DTYPE = current_platform.fp8_dtype()
# dtype used to hold fp4 data.
FP4_DTYPE = torch.uint8

# Ops that the fusion patterns in this file match on and emit.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    Subclasses choose the quantization scheme through `quant_key` and
    implement `_register()` to add their pattern/replacement pair to a
    `PatternMatcherPass`.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
    ):
        """
        Args:
            layer: attention layer whose `layer_name`, `num_heads` and
                `head_size` are baked into the pattern.
            quant_key: quantization scheme; must be a key of `QUANT_OPS`.

        Raises:
            AssertionError: if `quant_key` has no entry in `QUANT_OPS`.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size

        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        """Allocate an uninitialized tensor in the quantized dtype.

        Defaults to `dtype=self.quant_dtype` and `device="cuda"`; callers
        can override either through `kwargs`.
        """
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(process_fx, trace_fn):
        """Return a trace function whose result is post-processed.

        The returned callable invokes `trace_fn(*args, **kwargs)` and passes
        the resulting graph module through `process_fx`.
        """

        def wrapped(*args, **kwargs):
            return process_fx(trace_fn(*args, **kwargs))

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        """Apply Inductor's `view_to_reshape` pass to `gm` and return it.

        Presumably done so traced patterns use reshapes (see RESHAPE_OP),
        matching graphs where views were already rewritten — confirm.
        """
        from torch._inductor.fx_passes.post_grad import view_to_reshape
        view_to_reshape(gm)
        return gm

    def register_if_supported(self, pm_pass: PatternMatcherPass):
        """Register the pattern only when the layer's attention impl reports
        support for fusing output quantization with `self.quant_key`."""
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass):
        """Add this fusion's pattern/replacement pair to `pm_pass`."""
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        symmetric: bool = True,
    ):
        quant_key = QuantKey(dtype=FP8_DTYPE,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(layer, quant_key)

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused graph: attention without output quantization, a reshape
            # to (q.shape[0], num_heads * head_size), then the static fp8
            # quant op writing into `output_quant`.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=None,
                                      output_block_scale=None)
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size])
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=output_quant,
                                      input=attn_out_view,
                                      scale=scale)
            return at2[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            # attn output in quant_dtype
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # Fused graph: the quant scale is handed to attention directly
            # and the separate quant op disappears.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=scale,
                                      output_block_scale=None)
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Need custom fake mode, otherwise tracing happens with real tensors.
        # That would not work for the unified_attention custom op.
        with unset_fake_temporarily(), FakeTensorMode():
            inputs = [
                empty_bf16(5, self.num_heads, self.head_size),  # q
                empty_bf16(5, self.num_heads, self.head_size),  # k
                empty_bf16(5, self.num_heads, self.head_size),  # v
                empty_bf16(5, self.num_heads, self.head_size),  # attn_output
                self.empty_quant(5, self.num_heads *
                                 self.head_size),  # quant_output
                empty_fp32(1, 1)  # scale
            ]

            pm.register_replacement(
                pattern, replacement, inputs,
                AttentionQuantPattern.wrap_trace_fn(
                    AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
                pm_pass)


class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph. Its `input_scale` is
    passed into the Attention op as the `output_scale` argument, and its
    block-scale buffer (`output_scale`, viewed as the fp8 dtype) is passed
    as the `output_block_scale` argument.
    """

    def __init__(self, layer: Attention):
        super().__init__(layer, kNvfp4Quant)

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    output_scale: torch.Tensor, input_scale: torch.Tensor):
            # Unfused graph: attention without output quantization, a reshape
            # to (q.shape[0], num_heads * head_size), then the nvfp4 quant op.
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=None,
                                      output_block_scale=None)
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size])
            at2 = auto_functionalized(self.QUANT_OP,
                                      output=output_quant,
                                      input=attn_out_view,
                                      output_scale=output_scale,
                                      input_scale=input_scale)
            # Downstream consumers read the block scales reinterpreted as
            # FP8_DTYPE.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        output_scale: torch.Tensor, input_scale: torch.Tensor):
            # attention output in quant_dtype; the last dim is head_size // 2,
            # presumably because two fp4 values are packed per element.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(
                output_scale, FP8_DTYPE)
            # Fused graph: the quant op's input_scale becomes attention's
            # output_scale, and its block-scale buffer is filled by attention.
            at2 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=output_attn,
                                      layer_name=self.layer_name,
                                      output_scale=input_scale,
                                      output_block_scale=output_scale_view)
            output = RESHAPE_OP(at2[1],
                                [-1, self.num_heads * self.head_size // 2])
            # at2[2] is presumably the functionalized output_block_scale,
            # matching the pattern's second output — TODO confirm.
            return output, at2[2]

        # Need custom fake mode, otherwise tracing happens with real tensors.
        # That would not work for the unified_attention custom op.
        with unset_fake_temporarily(), FakeTensorMode():
            inputs = [
                empty_bf16(5, self.num_heads, self.head_size),  # q
                empty_bf16(5, self.num_heads, self.head_size),  # k
                empty_bf16(5, self.num_heads, self.head_size),  # v
                empty_bf16(5, self.num_heads, self.head_size),  # output_attn
                self.empty_quant(5, self.num_heads * self.head_size //
                                 2),  # output_quant
                # Block-scale buffer: (num_heads * head_size // 16) columns
                # rounded up to a multiple of 4; assumed to mirror the quant
                # kernel's scale layout — TODO confirm.
                empty_i32(128,
                          round_up(self.num_heads * self.head_size // 16,
                                   4)),  # output_scale
                empty_fp32(1, 1),  # input_scale
            ]

            pm.register_replacement(
                pattern, replacement, inputs,
                AttentionQuantPattern.wrap_trace_fn(
                    AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
                pm_pass)


class AttnFusionPass(VllmInductorPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, static fp8 quant and nvfp4 quant are supported; patterns could
    easily be added for other quant schemes and dtypes. The bigger hurdle for
    wider support are attention kernels, which need to support fusing output
    quant.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        # AttentionQuantPattern asserts that its quant key has an op in
        # QUANT_OPS. Only build nvfp4 patterns when that op is present, so the
        # pass can still be constructed (and fuse fp8) without it.
        nvfp4_available = kNvfp4Quant in QUANT_OPS
        if not nvfp4_available:
            logger.debug("No nvfp4 quant op registered; skipping "
                         "Attention + nvfp4 quant fusion patterns.")

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
            pattern_fp8.register_if_supported(self.patterns)

            if nvfp4_available:
                pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
                pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered.")

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered fusion patterns to `graph` in place."""
        self.begin()
        self.dump_graph(graph, "before_attn_fusion")

        count = self.patterns.apply(graph)

        # TODO: Move this to pass_manager.py after the fx graph broken issue
        # has been resolved.
        # see https://github.com/vllm-project/vllm/issues/23091
        graph.eliminate_dead_code()

        logger.debug("Fused quantization onto %s attention nodes", count)
        self.dump_graph(graph, "after_attn_fusion")
        self.end_and_log()

    def uuid(self):
        """Return an identifier computed by `VllmInductorPass.hash_source`
        over this pass and the pattern classes it registers."""
        return VllmInductorPass.hash_source(self, AttentionQuantPattern,
                                            AttentionFp8StaticQuantPattern,
                                            AttentionNvfp4QuantPattern)