# attn_quant_fusion.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
from collections.abc import Callable
6

7
8
9
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

10
from vllm.config import VllmConfig, get_layers_from_vllm_config
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.attention import Attention
13
from vllm.model_executor.layers.quantization.utils.quant_utils import (
14
    QuantKey,
15
    kNvfp4Dynamic,
16
17
    kStaticTensorScale,
)
18
from vllm.platforms import current_platform
19
from vllm.utils.math_utils import round_up
20

21
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
22
from .matcher_utils import MatcherQuantFP8
23
from .rms_quant_fusion import QUANT_OPS
24
25

logger = init_logger(__name__)

# Platform-specific FP8 dtype; used both as the fused attention output dtype
# and as the dtype the NVFP4 block-scale buffer is viewed as.
FP8_DTYPE = current_platform.fp8_dtype()
# NVFP4 outputs are held in uint8 buffers. The NVFP4 replacement below halves
# the last dimension when allocating them (presumably two 4-bit values per
# byte -- confirm against the quant kernel).
FP4_DTYPE = torch.uint8

# Ops that the patterns in this module match on / emit.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default

# Quant scheme matched by AttnFp8StaticQuantPattern and used to query
# `fused_output_quant_supported()` on each attention implementation.
_FP8_QUANT_KEY = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=True)
35

36
37

class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.

    One instance is created per attention layer because the layer name is
    baked into both the pattern and the replacement.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """
        Capture the per-layer constants needed to build the pattern.

        Args:
            layer: Attention layer whose name, head count and head size are
                embedded in the traced pattern/replacement.
            dtype: dtype used for the example inputs in `get_inputs()`.
        """
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._head_size = layer.head_size
        self._dtype = dtype
        self._quant_matcher = MatcherQuantFP8(_FP8_QUANT_KEY)

    @property
    def pattern(self) -> Callable[..., torch.Tensor]:
        """Return the function tracing unfused attention -> reshape -> fp8 quant."""

        def _pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused attention: no output scale is handed to the op.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # at1[1] is used as the attention output; flatten heads before quant.
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self._num_heads * self._head_size]
            )
            return self._quant_matcher(attn_out_view, scale)[0]

        return _pattern

    @property
    def replacement(self) -> Callable[..., torch.Tensor]:
        """Return the function tracing attention with the quant scale fused in."""

        def _replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # The matched output buffer is discarded: the fused op writes
            # into a freshly allocated FP8_DTYPE buffer instead.
            output_attn = torch.empty(
                [q.shape[0], self._num_heads, self._head_size],
                dtype=FP8_DTYPE,
                device=q.device,
            )
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return RESHAPE_OP(at1[1], [-1, self._num_heads * self._head_size])

        return _replacement

    def get_inputs(self):
        """
        Build example inputs, in the positional order `_pattern` expects.

        The leading dimension of 5 is an arbitrary example size.
        """
        dtype = self._dtype
        num_heads = self._num_heads
        head_size = self._head_size
        return [
            self.empty(5, num_heads, head_size, dtype=dtype),  # q
            self.empty(5, num_heads, head_size, dtype=dtype),  # k
            self.empty(5, num_heads, head_size, dtype=dtype),  # v
            self.empty(5, num_heads, head_size, dtype=dtype),  # attn_output
            self.empty_fp32(1, 1),  # scale
            self.empty(0, dtype=dtype),  # kv_cache_dummy_dep
        ]

125

126
127
128
class AttnNvfp4QuantPattern(
    VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph. Its `input_scale` is
    passed into the Attention op as the `output_scale` argument, and its
    scale output buffer (viewed as FP8_DTYPE) is passed as
    `output_block_scale`.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """
        Capture the per-layer constants needed to build the pattern.

        Args:
            layer: Attention layer whose name, head count and head size are
                embedded in the traced pattern/replacement.
            dtype: dtype used for the `kv_cache_dummy_dep` example input in
                `get_inputs()`.
        """
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._head_size = layer.head_size
        self._dtype = dtype
        self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic]

    @property
    def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        """Return the function tracing unfused attention -> reshape -> nvfp4 quant."""

        def _pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused attention: no scales are handed to the op.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # at1[1] is used as the attention output; flatten heads before quant.
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self._num_heads * self._head_size]
            )
            at2 = auto_functionalized(
                self._QUANT_OP,
                input=attn_out_view,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
                output=output_quant,
                output_scale=output_scale,
            )
            # The scale buffer is reinterpreted as FP8_DTYPE after the quant op.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        return _pattern

    @property
    def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        """Return the function tracing attention with nvfp4 quant fused in."""

        def _replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            _output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The matched output buffer is discarded: the fused op writes
            # into a fresh FP4_DTYPE buffer with half the head size.
            output_attn = torch.empty(
                [q.shape[0], self._num_heads, self._head_size // 2],
                dtype=FP4_DTYPE,
                device=q.device,
            )
            # Here the FP8_DTYPE view is taken before the op, so the attention
            # op receives the already-viewed block-scale buffer.
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            at2 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            output = RESHAPE_OP(at2[1], [-1, self._num_heads * self._head_size // 2])
            # at2[2] is presumably the mutated block-scale buffer -- confirm
            # against the ATTN_OP schema.
            return output, at2[2]

        return _replacement

    def get_inputs(self):
        """
        Build example inputs, in the positional order `_pattern` expects.

        The leading dimension of 5 is an arbitrary example size.
        """
        dtype = self._dtype
        num_heads = self._num_heads
        head_size = self._head_size
        # NOTE(review): q/k/v/output_attn are always bf16 here, whereas the
        # FP8 pattern uses `self._dtype`; only the dummy dependency follows
        # `dtype`. Confirm this is intentional.
        return [
            self.empty_bf16(5, num_heads, head_size),  # q
            self.empty_bf16(5, num_heads, head_size),  # k
            self.empty_bf16(5, num_heads, head_size),  # v
            self.empty_bf16(5, num_heads, head_size),  # output_attn
            self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE),  # output_quant
            self.empty_i32(
                128, round_up(num_heads * head_size // 16, 4)
            ),  # output_scale
            self.empty_fp32(1, 1),  # input_scale
            self.empty(0, dtype=dtype),  # kv_cache_dummy_dep
        ]

235

236
class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, static fp8 quant and (on CUDA builds that expose
    `scaled_fp4_quant`) NVFP4 quant are registered, but patterns could easily
    be added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    def __init__(self, config: VllmConfig) -> None:
        """
        Register one fusion pattern per supporting attention layer.

        Args:
            config: vLLM config; supplies the model dtype and the attention
                layers to register patterns for.
        """
        super().__init__(config, "attn_quant_fusion")

        dtype = config.model_config.dtype
        layers = list(get_layers_from_vllm_config(config, Attention).values())

        if not layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        # FP8 patterns for all layers are registered before any NVFP4
        # pattern; the two loops are kept separate to preserve that order.
        for layer in layers:
            if layer.impl.fused_output_quant_supported(_FP8_QUANT_KEY):
                self.register(AttnFp8StaticQuantPattern(layer, dtype))

        # NVFP4 patterns are only attempted on CUDA when the
        # `scaled_fp4_quant` custom op is present.
        if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
            for layer in layers:
                if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
                    self.register(AttnNvfp4QuantPattern(layer, dtype))

        self.dump_patterns(config, self.pm_pass)