# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    _normalize_quant_group_shape,
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform

# Op overloads referenced by the matcher classes in this module.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default

# Quantization scheme -> op overload implementing it. MatcherQuantFP8 looks
# its op up here unless it is matching the ROCm aiter ops.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}

# The fp4 op is only registered when running on CUDA and the compiled
# extension actually exposes it.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501

# Both per-group keys map to the same op; the group size is passed to the op
# at call time (see MatcherQuantFP8.forward_custom).
if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501

SILU_MUL_OP = torch.ops._C.silu_and_mul.default


class MatcherCustomOp(ABC):
    """Base class for callables that stand in for an op when building patterns.

    Calling an instance dispatches to ``forward_custom`` when ``enabled`` is
    true and to ``forward_native`` otherwise. The choice is made once, in
    ``__init__``.

    Attributes set by ``__init__``:
        model_dtype: ``config.model_config.dtype`` of the current vLLM
            config, or ``None`` when no model config is set.
        device: ``config.device_config.device`` of the current vLLM config,
            or ``None`` when no device config is set.
        enabled: the flag passed to ``__init__``.
        forward: the bound method that ``__call__`` delegates to.
    """

    def __init__(self, enabled: bool):
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None

        self.enabled = enabled
        # Bind the implementation once so __call__ is a plain delegation.
        self.forward = self.forward_custom if enabled else self.forward_native

    @abstractmethod
    def forward_custom(self, *args, **kws):
        """Implementation used when the op is enabled."""
        pass

    @abstractmethod
    def forward_native(self, *args, **kws):
        """Implementation used when the op is not enabled."""
        pass

    def __call__(self, *args, **kws):
        """Forward all arguments to the implementation chosen in ``__init__``."""
        return self.forward(*args, **kws)

    def empty(self, *args, **kws):
        """Return ``torch.empty(...)`` with the model dtype on ``self.device``."""
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

    def empty_int64(self, *args, **kws):
        """Return ``torch.empty(...)`` with dtype int64 on ``self.device``."""
        return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)

    def empty_f32(self, *args, **kws):
        """Return ``torch.empty(...)`` with dtype float32 on ``self.device``."""
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError


class MatcherRotaryEmbedding(MatcherCustomOp):
    """Matcher for the rotary-embedding op.

    The custom path calls the selected rotary op through
    ``auto_functionalized``; the native path calls
    ``RotaryEmbedding.forward_static``.
    """

    def __init__(
        self,
        is_neox: bool,
        head_size: int,
        num_heads: int,
        num_kv_heads: int,
        use_flashinfer: bool = False,
        enabled: bool | None = None,
    ) -> None:
        # Fall back to the RotaryEmbedding setting only when the caller did
        # not choose explicitly.
        super().__init__(RotaryEmbedding.enabled() if enabled is None else enabled)

        self.is_neox = is_neox
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # Widths of the query / key inputs built by inputs().
        self.q_size = num_heads * head_size
        self.kv_size = num_kv_heads * head_size
        # The rotary dimension is taken to be the full head size.
        self.rotary_dim = head_size
        self.rotary_op = FLASHINFER_ROTARY_OP if use_flashinfer else ROTARY_OP

    def inputs(self) -> list[torch.Tensor]:
        """Return example ``[positions, query, key, cos_sin_cache]`` tensors."""
        num_tokens = 5
        return [
            self.empty_int64(num_tokens),
            self.empty(num_tokens, self.q_size),
            self.empty(num_tokens, self.kv_size),
            self.empty(4096, self.rotary_dim),
        ]

    def forward_custom(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply ``self.rotary_op`` via ``auto_functionalized``.

        Returns element 1 of the functionalized result as the query and
        element 2 as the key; the key is ``None`` when the result has no
        third element.
        """
        functionalized = auto_functionalized(
            self.rotary_op,
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
        )
        rotated_query = functionalized[1]
        if len(functionalized) > 2:
            return rotated_query, functionalized[2]
        return rotated_query, None

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Delegate to ``RotaryEmbedding.forward_static``."""
        static_args = (
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            cos_sin_cache,
            self.is_neox,
        )
        return RotaryEmbedding.forward_static(*static_args)


class MatcherRMSNorm(MatcherCustomOp):
    """Matcher for RMS normalization without a residual input.

    The custom path calls ``RMS_OP`` through ``auto_functionalized``, or the
    ROCm aiter RMSNorm op directly when ``match_rocm_aiter`` is set. The
    native path calls ``RMSNorm.forward_static``.
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ):
        # Default to the RMSNorm setting when the caller did not choose.
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self._rmsnorm_op = RMS_OP
        self.match_rocm_aiter = match_rocm_aiter

        if match_rocm_aiter:
            self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()

    def inputs(self):
        """Return example ``[input, weight]`` tensors for the pattern.

        The input uses the model dtype when the custom op is enabled and
        float32 otherwise; the weight always uses the model dtype.
        """
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        return [input, weight]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Call the ROCm aiter RMSNorm op directly (no functionalization)."""
        return self._rmsnorm_op(
            x=input,
            weight=weight,
            variance_epsilon=self.epsilon,
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Run the custom RMSNorm op and return the normalized tensor."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight)

        # The op writes into ``result``; auto_functionalized hands back the
        # written buffer as the second element.
        result = torch.empty_like(input)
        _, result = auto_functionalized(
            self._rmsnorm_op,
            result=result,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``RMSNorm.forward_static`` using the last input dim."""
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight
        )


class MatcherFusedAddRMSNorm(MatcherCustomOp):
    """Matcher for RMS normalization fused with a residual add.

    The custom path calls ``RMS_ADD_OP`` through ``auto_functionalized``, or
    the ROCm aiter fused-add op directly when ``match_rocm_aiter`` is set.
    The native path calls ``RMSNorm.forward_static`` with a residual.
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ):
        # Default to the RMSNorm setting when the caller did not choose.
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter

        self._rmsnorm_op = RMS_ADD_OP

        if match_rocm_aiter:
            self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()

    def inputs(self):
        """Return example ``[input, weight, residual]`` tensors.

        The input uses the model dtype when the custom op is enabled and
        float32 otherwise; weight and residual always use the model dtype.
        """
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [input, weight, residual]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the ROCm aiter fused-add op directly (no functionalization)."""
        return self._rmsnorm_op(
            x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the custom fused op and return ``(result, residual)``."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight, residual)

        # No output buffer is allocated here: only ``input`` and ``residual``
        # are passed, and their functionalized values are read back from
        # elements 1 and 2 of the result.
        _, result, residual = auto_functionalized(
            self._rmsnorm_op,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result, residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to ``RMSNorm.forward_static`` with the residual."""
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
        )


class MatcherQuantFP8(MatcherCustomOp):
    """Matcher for FP8 quantization.

    The custom path calls the op selected for ``quant_key`` (from
    ``QUANT_OPS``, or a ROCm aiter / triton op when ``match_rocm_aiter`` is
    set). The native path calls a ``QuantFP8`` instance built from the same
    key.

    Args:
        quant_key: quantization scheme to match.
        enabled: force the custom (True) or native (False) path; defaults to
            ``QuantFP8.enabled()``.
        has_col_major_scales: forwarded to ``QuantFP8`` as
            ``column_major_scales`` and used to build a transposed scale
            buffer on the per-group custom path.
        is_e8m0: forwarded to ``QuantFP8`` as ``use_ue8m0`` and to the
            per-group op as ``scale_ue8m0``.
        match_rocm_aiter: match the ROCm aiter ops instead of ``QUANT_OPS``.

    Raises:
        AssertionError: if the scheme is unsupported for the selected path.
    """

    def __init__(
        self,
        quant_key: QuantKey,
        enabled: bool | None = None,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        match_rocm_aiter: bool = False,
    ):
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        self.has_col_major_scales = has_col_major_scales
        self.is_e8m0 = is_e8m0
        self.match_rocm_aiter = match_rocm_aiter

        if match_rocm_aiter:
            assert not quant_key.scale.group_shape.is_per_tensor(), (
                "ROCm aiter fusion pass does not support per tensor quantization"
            )
            if quant_key.scale.group_shape.is_per_token():
                self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
            else:
                assert quant_key.scale.group_shape.col == 128, (
                    "ROCm aiter fusion pass currently supports "
                    "quantization operation with group_size 128"
                )
                if current_platform.is_fp8_fnuz():
                    self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
                else:
                    self.QUANT_OP = (
                        torch.ops.vllm.triton_per_token_group_quant_fp8.default
                    )

        else:
            assert quant_key in QUANT_OPS, (
                f"unsupported quantization scheme {quant_key}"
            )
            self.QUANT_OP = QUANT_OPS[quant_key]

            assert quant_key.dtype == current_platform.fp8_dtype(), (
                "Only QuantFP8 supported by MatcherQuantFP8: expected dtype "
                f"{current_platform.fp8_dtype()}, got {quant_key.dtype}"
            )
            assert quant_key.scale2 is None

        # Native implementation used by forward_native.
        self.quant_fp8 = QuantFP8(
            quant_key.scale.static,
            quant_key.scale.group_shape,
            column_major_scales=has_col_major_scales,
            use_ue8m0=is_e8m0,
        )

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the op chosen for the ROCm aiter path directly.

        Per-token keys pass the dtype and ``scale`` by keyword; every other
        key passes the input and the group width positionally.
        """
        quant_key_group_shape = self.quant_key.scale.group_shape
        if quant_key_group_shape == GroupShape.PER_TOKEN:
            return self.QUANT_OP(
                x=input,
                quant_dtype=self.quant_key.dtype,
                scale=scale,
            )
        else:
            return self.QUANT_OP(input, quant_key_group_shape.col)

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``input`` with the custom op; return ``(result, scale)``.

        ``scale`` must be provided for static schemes and must be ``None``
        for per-group and dynamic schemes, where a scale buffer is allocated
        here and filled by the op.
        """
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, scale)

        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.group_shape.is_per_group():
            assert scale is None
            scale = self.make_scale(input, transposed=self.has_col_major_scales)

            finfo = torch.finfo(self.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            _, result, scale = auto_functionalized(
                self.QUANT_OP,
                input=input,
                output_q=result,
                output_s=scale,
                group_size=self.quant_key.scale.group_shape.col,
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.is_e8m0,
            )
            return result, scale

        if self.quant_key.scale.static:
            assert scale is not None
            # Static scale: the op only writes ``result``; the caller's
            # scale is returned unchanged.
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the ``QuantFP8`` instance built in ``__init__``."""
        return self.quant_fp8(input, scale)

    def make_scale(self, input: torch.Tensor, transposed: bool = False):
        """Allocate an uninitialized float32 scale buffer for ``input``.

        The buffer has one entry per quantization group along the first two
        input dimensions. With ``transposed=True`` the storage is allocated
        with the two dimensions swapped and then permuted back, so the
        returned tensor has the same logical shape but transposed strides.
        """
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )
        if transposed:
            scale_shape = tuple(reversed(scale_shape))
            return torch.empty(
                scale_shape, device=input.device, dtype=torch.float32
            ).permute(-1, -2)

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        """Return example inputs: ``[input]`` plus a (1, 1) scale if static."""
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]


class MatcherSiluAndMul(MatcherCustomOp):
    """Matcher for the SiLU-and-multiply activation.

    The custom path calls ``SILU_MUL_OP`` through ``auto_functionalized``;
    the native path calls ``SiluAndMul.forward_native``.
    """

    def __init__(self, enabled: bool | None = None):
        # Fall back to the SiluAndMul setting only when the caller did not
        # choose explicitly.
        super().__init__(SiluAndMul.enabled() if enabled is None else enabled)

    def inputs(self) -> list[torch.Tensor]:
        """Return a single example input tensor of shape (5, 4)."""
        return [self.empty(5, 4)]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run the custom op into a buffer whose last dim is half of ``x``'s."""
        half_width = x.shape[-1] // 2
        leading_dims = x.shape[:-1]
        buffer = torch.empty(
            leading_dims + (half_width,), dtype=x.dtype, device=x.device
        )
        functionalized = auto_functionalized(SILU_MUL_OP, result=buffer, input=x)
        # Element 1 is the functionalized value of the ``result`` argument.
        return functionalized[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``SiluAndMul.forward_native``."""
        return SiluAndMul.forward_native(x)