fp8_utils.py 18.1 KB
Newer Older
1
import os
HandH1998's avatar
HandH1998 committed
2
from typing import List, Optional, Tuple
HAI's avatar
HAI committed
3
4

import torch
HandH1998's avatar
HandH1998 committed
5

Yineng Zhang's avatar
Yineng Zhang committed
6
7
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

Lianmin Zheng's avatar
Lianmin Zheng committed
8
try:
9
    from vllm import _custom_ops as ops
Lianmin Zheng's avatar
Lianmin Zheng committed
10
11
12
13
14

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

HandH1998's avatar
HandH1998 committed
15
from sglang.srt.layers.quantization.fp8_kernel import (
16
    _enable_jit_deepgemm,
HandH1998's avatar
HandH1998 committed
17
    per_token_group_quant_fp8,
Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
    scaled_fp8_quant,
    sglang_per_token_quant_fp8,
HandH1998's avatar
HandH1998 committed
20
    static_quant_fp8,
HandH1998's avatar
HandH1998 committed
21
22
    w8a8_block_fp8_matmul,
)
HandH1998's avatar
HandH1998 committed
23
24
25
26
from sglang.srt.utils import (
    get_bool_env_var,
    get_cuda_version,
    get_device_capability,
Lianmin Zheng's avatar
Lianmin Zheng committed
27
    is_cuda,
HandH1998's avatar
HandH1998 committed
28
29
30
    is_hip,
)

31
# Platform detection happens once, at import time. These flags decide which
# optional kernel libraries get imported and which GEMM paths are taken below.
_is_hip = is_hip()
_is_cuda = is_cuda()

# AMD CK block-scaled GEMM: only imported on ROCm when CK_MOE is requested.
if _is_hip:
    if get_bool_env_var("CK_MOE"):
        from aiter import gemm_a8w8_blockscale

# CUTLASS-based FP8 GEMMs from sgl-kernel (CUDA builds only).
if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

# Opt-in switch that routes FP8 GEMMs to vLLM's CUTLASS kernel instead of
# sgl-kernel's; it is only honored when vLLM could be imported.
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

# Input scaling factors are no longer optional in torch._scaled_mm starting
# from PyTorch 2.5, so a dummy all-ones tensor is passed as the scale where no
# real scale applies. It is allocated lazily by _apply_fallback_scaled_mm.
TORCH_DEVICE_IDENTITY = None
HandH1998's avatar
HandH1998 committed
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
    """Return the numeric ``(major, minor, patch)`` prefix of a version string.

    Anything after the leading numeric components is ignored: pre-release or
    dev tags (``"2.7.0a0"`` -> ``(2, 7, 0)``), and any remaining local suffix.
    Missing minor/patch components default to 0 (``"2.7"`` -> ``(2, 7, 0)``),
    so the result always compares correctly against a 3-tuple. Returns
    ``(0, 0, 0)`` when the string does not start with a number.
    """
    import re

    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        return (0, 0, 0)
    major, minor, patch = (int(part) if part else 0 for part in match.groups())
    return (major, minor, patch)


_TORCH_VERSION = torch.__version__.split("+")[0]
# Previously a pre-release string such as "2.8.0a0" made int() raise and fell
# back to (0, 0, 0), silently disabling the rowwise path below even on a
# new-enough PyTorch. Pre-release tags are now ignored, so "2.7.0a0" counts
# as 2.7.0.
_TORCH_VERSION_TUPLE = _parse_version_tuple(_TORCH_VERSION)

# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
)

HandH1998's avatar
HandH1998 committed
60
61
62
63
64
65
66
67
68
69
70
71

def cutlass_fp8_supported():
    """Return whether the CUTLASS FP8 scaled-mm path can be used here.

    Requires a CUDA platform, plus one of:
      * compute capability major >= 9 with CUDA >= 12.0, or
      * compute capability exactly 8.9 with CUDA >= 12.4.
    Any other device returns False.
    """
    if not _is_cuda:
        return False
    major, minor = get_device_capability()
    cuda_version = get_cuda_version()
    if major >= 9:
        min_cuda = (12, 0)
    elif major == 8 and minor == 9:
        min_cuda = (12, 4)
    else:
        return False
    return cuda_version >= min_cuda

HAI's avatar
HAI committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bits pattern 10000000(-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz. So here we set it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
HandH1998's avatar
HandH1998 committed
95
96


97
def cutlass_block_fp8_supported() -> bool:
    """Return whether the CUTLASS block-wise FP8 GEMM path may be used.

    All of the following must hold:
      * the SUPPORT_CUTLASS_BLOCK_FP8 environment flag is set (opt-in);
      * the platform is CUDA;
      * the device SM version (10 * major + minor) is at least 90;
      * the CUDA version reported by ``torch.version.cuda`` is at least 12.0.
    """
    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
        return False
    if not _is_cuda:
        return False
    major, minor = torch.cuda.get_device_capability()
    sm_version = 10 * major + minor
    cuda_version = tuple(int(part) for part in torch.version.cuda.split("."))
    return cuda_version >= (12, 0) and sm_version >= 90


# Evaluated once, at import time.
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()


HandH1998's avatar
HandH1998 committed
112
113
114
115
116
117
118
119
120
121
122
123
def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Linear layer with block-wise FP8 weights and dynamically quantized input.

    The activation is quantized on the fly in per-token groups of
    ``block_size[1]`` elements. The first matching backend is used:

    * sgl-kernel CUTLASS ``fp8_blockwise_scaled_mm``, when
      CUTLASS_BLOCK_FP8_SUPPORTED and both weight dims are multiples of 128;
    * AMD CK ``gemm_a8w8_blockscale``, on ROCm with CK_MOE set;
    * ``w8a8_block_fp8_matmul``, using TMA-aligned column-major activation
      scales when JIT DeepGEMM is enabled, row-major scales otherwise.

    Args:
        input: activations; all leading dims are flattened into tokens.
        weight: FP8 weight; ``weight.shape[0]`` is the output feature size.
        block_size: quantization block size; ``block_size[1]`` is the
            activation group size.
        weight_scale: block scales for ``weight`` (layout as expected by the
            selected kernel).
        input_scale: must be None; only dynamic activation quantization is
            supported.
        bias: optional bias added to the output.

    Returns:
        Tensor of shape ``(*input.shape[:-1], weight.shape[0])`` in
        ``input.dtype``.
    """
    assert input_scale is None
    # View input as 2D matrix for fp8 methods
    tokens = input.view(-1, input.shape[-1])
    out_shape = [*input.shape[:-1], weight.shape[0]]
    group_size = block_size[1]

    # TODO: add more robust shape check here
    # NOTE(review): this only checks for multiples of 128 and does not look at
    # block_size; confirm the CUTLASS kernel is only used with 128x128 blocks.
    cutlass_shape_ok = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0

    if CUTLASS_BLOCK_FP8_SUPPORTED and cutlass_shape_ok:
        q_tokens, token_scales = per_token_group_quant_fp8(
            tokens, group_size, column_major_scales=True
        )
        result = fp8_blockwise_scaled_mm(
            q_tokens, weight.T, token_scales, weight_scale.T, out_dtype=input.dtype
        )
    elif _is_hip and get_bool_env_var("CK_MOE"):
        q_tokens, token_scales = per_token_group_quant_fp8(
            tokens, group_size, column_major_scales=False
        )
        # The CK kernel receives this tensor as its output buffer.
        result = torch.zeros(
            [q_tokens.shape[0], weight.shape[0]],
            dtype=input.dtype,
            device=q_tokens.device,
        )
        gemm_a8w8_blockscale(q_tokens, weight, token_scales, weight_scale, result)
    elif _enable_jit_deepgemm:
        q_tokens, token_scales = sglang_per_token_group_quant_fp8(
            tokens,
            group_size,
            column_major_scales=True,
            scale_tma_aligned=True,
        )
        result = w8a8_block_fp8_matmul(
            q_tokens,
            weight,
            token_scales,
            weight_scale,
            block_size,
            output_dtype=input.dtype,
        )
    else:
        q_tokens, token_scales = per_token_group_quant_fp8(
            tokens, group_size, column_major_scales=False
        )
        result = w8a8_block_fp8_matmul(
            q_tokens,
            weight,
            token_scales,
            weight_scale,
            block_size,
            output_dtype=input.dtype,
        )

    if bias is not None:
        result = result + bias
    return result.to(dtype=input.dtype).view(*out_shape)


def input_to_float8(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to float8 with a single tensor-wise scale.

    The scale maps the largest absolute value of ``x`` onto the FP8 maximum.
    Scaled values are clamped to ``[-fp8_max, fp8_max]`` before the cast. On
    ROCm (``_is_hip``), the requested ``dtype`` is replaced by
    ``torch.float8_e4m3fnuz`` and 224.0 is used as the clamp bound.

    Returns:
        ``(quantized, inv_scale)``: a contiguous tensor in the target FP8
        dtype, and the float32 reciprocal of the scale, so that
        ``x ~= quantized.float() * inv_scale``.
    """
    fp8_max = torch.finfo(dtype).max
    lo, hi = x.aminmax()
    # The clamp keeps an all-zero tensor from causing a division by zero.
    amax = torch.maximum(lo.abs(), hi.abs()).float().clamp(min=1e-12)
    if _is_hip:
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0
    scale = fp8_max / amax
    quantized = (
        (x.float() * scale).clamp(min=-fp8_max, max=fp8_max).to(dtype).contiguous()
    )
    return quantized, scale.float().reciprocal()


def block_quant_to_tensor_quant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert block-wise quantization to tensor-wise quantization.

    ``x_q_block`` (2D, shape ``(n, k)``) is dequantized with one scale per
    ``block_size[0] x block_size[1]`` tile, taken from ``x_s`` of shape
    ``(ceil(n / block_n), ceil(k / block_k))``. Edge tiles may be partial. The
    result is then re-quantized with a single tensor-wide scale, using
    ``scaled_fp8_quant`` on CUDA or ``input_to_float8`` otherwise.

    Note only float8 is supported for now.

    Returns:
        ``(tensor-wise quantized tensor, tensor-wise scale)``.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # Expand each tile scale over its tile and crop the ragged edge tiles, so
    # the dequantization is one broadcasted multiply instead of one
    # elementwise op per tile. The scale is cast to float32 so the product is
    # computed in float32. This allocates an (n, k) float32 scale tensor.
    full_scale = (
        x_s.repeat_interleave(block_n, dim=0)[:n]
        .repeat_interleave(block_k, dim=1)[:, :k]
        .to(torch.float32)
    )
    x_dq_block = x_q_block.to(torch.float32)
    # In place: for FP8 input, .to(float32) above already produced a fresh copy.
    x_dq_block.mul_(full_scale)

    x_q_tensor, scale = (
        scaled_fp8_quant(x_dq_block)
        if _is_cuda
        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
    )
    return x_q_tensor, scale


225
226
227
228
229
def channel_quant_to_tensor_quant(
    x_q_channel: torch.Tensor,
    x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Re-quantize a per-channel quantized tensor with one tensor-wide scale.

    The input is dequantized as ``x_q_channel.float() * x_s`` (``x_s`` must
    broadcast against it). It is then quantized again with
    ``scaled_fp8_quant`` on CUDA, or ``input_to_float8`` (keeping the input's
    dtype) otherwise.

    Returns:
        ``(tensor-wise quantized tensor, tensor-wise scale)``.
    """
    dequantized = x_q_channel.to(torch.float32) * x_s
    if _is_cuda:
        quantized, tensor_scale = scaled_fp8_quant(dequantized)
    else:
        quantized, tensor_scale = input_to_float8(dequantized, dtype=x_q_channel.dtype)
    return quantized, tensor_scale


238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def _process_scaled_mm_output(output, input_2d_shape, output_shape):
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)


def _apply_fallback_scaled_mm(
    qinput,
    weight,
    x_scale,
    weight_scale,
    input_2d_shape,
    output_shape,
    bias,
    input_dtype,
):
    """Quantized GEMM with unfused dequantization.

    Runs ``torch._scaled_mm`` with unit scales to get the raw product
    ``qinput @ weight`` in float32, and drops padded rows. It then applies
    ``x_scale`` (narrowed to the real rows) and ``weight_scale.t()`` in
    float32, adds ``bias``, and casts to ``input_dtype``. This handles scale
    layouts the fused ``torch._scaled_mm`` call is not used for, e.g. vector
    weight scales.
    """
    global TORCH_DEVICE_IDENTITY
    # The unit scale is cached at module level, but torch._scaled_mm expects
    # its scale tensors on the operands' device. Re-create the cache when the
    # weight lives on a different device than the cached tensor (e.g. several
    # GPUs used from one process); previously the first device was kept forever.
    if (
        TORCH_DEVICE_IDENTITY is None
        or TORCH_DEVICE_IDENTITY.device != weight.device
    ):
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)

    output = torch._scaled_mm(
        qinput,
        weight,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )

    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
    # x_scale may carry rows for padded tokens; keep only the real ones.
    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])

    output = output * x_scale * weight_scale.t()
    if bias is not None:
        output = output + bias
    return output.to(dtype=input_dtype)


HandH1998's avatar
HandH1998 committed
275
276
277
278
279
280
281
def _cutlass_fp8_scaled_mm(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Run a fused CUTLASS FP8 GEMM + dequantization on 2D operands.

    Dispatches to vLLM's ``cutlass_scaled_mm`` when vLLM is importable and
    USE_VLLM_CUTLASS_W8A8_FP8_KERNEL is set. Otherwise it uses sgl-kernel's
    ``fp8_scaled_mm``, which requires one weight scale per output column
    (asserted). Returns the un-reshaped 2D result.
    """
    if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
        # Fall back to vllm cutlass w8a8 fp8 kernel
        return ops.cutlass_scaled_mm(
            qinput,
            weight,
            out_dtype=out_dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
        )
    assert (
        weight_scale.numel() == weight.shape[1]
    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
    return fp8_scaled_mm(
        qinput,
        weight,
        x_scale,
        weight_scale,
        out_dtype=out_dtype,
        bias=bias,
    )


def _torch_fp8_scaled_mm(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    input_dtype: torch.dtype,
    input_2d_shape: torch.Size,
    output_shape: List[int],
    allow_rowwise: bool,
) -> torch.Tensor:
    """GEMM via ``torch._scaled_mm``, with an unfused-dequant fallback.

    * per-tensor weight and activation scales: one fused ``torch._scaled_mm``;
    * neither per-tensor, and ``allow_rowwise``: fused rowwise
      ``torch._scaled_mm``;
    * anything else: ``_apply_fallback_scaled_mm``.

    Returns the output reshaped to ``output_shape`` in ``input_dtype``.
    """
    # torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    per_tensor_weights = weight_scale.numel() == 1
    per_tensor_activations = x_scale.numel() == 1

    if per_tensor_weights and per_tensor_activations:
        # Fused GEMM_DQ
        output = torch._scaled_mm(
            qinput,
            weight,
            out_dtype=input_dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
        )
        return _process_scaled_mm_output(output, input_2d_shape, output_shape)

    if allow_rowwise and not per_tensor_weights and not per_tensor_activations:
        # For now validated on ROCm platform
        # fp8 rowwise scaling in torch._scaled_mm is introduced in
        # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
        # and ROCm 6.3, which only exists in torch 2.7 and above.
        # For CUDA platform please validate if the
        # torch._scaled_mm support rowwise scaled GEMM
        # Fused GEMM_DQ Rowwise GEMM
        output = torch._scaled_mm(
            qinput,
            weight,
            out_dtype=input_dtype,
            scale_a=x_scale,
            scale_b=weight_scale.t(),
            bias=bias,
        )
        return _process_scaled_mm_output(output, input_2d_shape, output_shape)

    # Fallback for channelwise case, where we use unfused DQ
    # due to limitations with scaled_mm
    #
    # Symmetric quantized GEMM by definition computes the following:
    #   C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    #   C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.
    return _apply_fallback_scaled_mm(
        qinput,
        weight,
        x_scale,
        weight_scale,
        input_2d_shape,
        output_shape,
        bias,
        input_dtype,
    )


def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
    use_per_token_if_dynamic: bool = False,
    pad_output: Optional[bool] = None,
    compressed_tensor_quant: bool = False,
) -> torch.Tensor:
    """Compute ``input @ weight (+ bias)`` with an FP8 weight and FP8 activations.

    The activation is quantized here: statically with ``input_scale`` when it
    is given, otherwise dynamically. The GEMM then runs on a CUTLASS kernel
    when ``cutlass_fp8_supported`` allows it, or else on ``torch._scaled_mm``
    or its unfused fallback.

    Args:
        input: activations; all leading dims are flattened into tokens.
        weight: quantized FP8 weight; ``weight.shape[1]`` is the output size.
        weight_scale: per-tensor (numel 1) or per-channel
            (numel == weight.shape[1]) weight scale.
        input_scale: static activation scale, or None for dynamic
            quantization. In the non-compressed-tensors path it must be
            per-tensor (asserted).
        input_scale_ub: not used by this function; kept for interface
            compatibility.
        bias: optional bias added to the output.
        cutlass_fp8_supported: whether to try the CUTLASS path. The default is
            evaluated once, when this module is imported.
        use_per_token_if_dynamic: forwarded to the ``scaled_fp8_quant``
            quantizers. The non-compressed CUDA dynamic path always quantizes
            per token regardless.
        pad_output: whether to pad the token dimension
            (``num_token_padding=17``) for ``torch._scaled_mm``. None means
            "pad unless SGLANG_ENABLE_TORCH_COMPILE is set". Only used by the
            compressed-tensors ``torch._scaled_mm`` path.
        compressed_tensor_quant: select the compressed-tensors code path.

    Returns:
        Tensor of shape ``(*input.shape[:-1], weight.shape[1])`` in
        ``input.dtype``.
    """
    # Note: we pad the input because torch._scaled_mm is more performant
    # for matrices with batch dimension > 16.
    # This could change in the future.
    # We also don't pad when using torch.compile,
    # as it breaks with dynamic shapes.
    if pad_output is None:
        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
    output_padding = 17 if pad_output else None

    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]

    if compressed_tensor_quant:
        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A;
        # sgl-kernel fp8_scaled_mm currently supports per-channel W only.
        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
            qinput, x_scale = scaled_fp8_quant(
                input_2d,
                input_scale,
                use_per_token_if_dynamic=use_per_token_if_dynamic,
            )
            # Fused GEMM_DQ
            output = _cutlass_fp8_scaled_mm(
                qinput, weight, x_scale, weight_scale, bias, input.dtype
            )
            return output.view(*output_shape)

        # torch._scaled_mm path: the token dimension may be padded according
        # to ``output_padding`` (see the note at the top of this function);
        # the padded rows are dropped again when the output is post-processed.
        qinput, x_scale = (
            scaled_fp8_quant(
                input_2d,
                input_scale,
                num_token_padding=output_padding,
                use_per_token_if_dynamic=use_per_token_if_dynamic,
            )
            if _is_cuda
            else ops.scaled_fp8_quant(
                input_2d,
                input_scale,
                num_token_padding=output_padding,
                use_per_token_if_dynamic=use_per_token_if_dynamic,
            )
        )
        return _torch_fp8_scaled_mm(
            qinput,
            weight,
            x_scale,
            weight_scale,
            bias,
            input.dtype,
            input_2d.shape,
            output_shape,
            allow_rowwise=use_per_token_if_dynamic and USE_ROWWISE_TORCH_SCALED_MM,
        )

    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
    if input_scale is not None:
        assert input_scale.numel() == 1
        # broadcast per-tensor scale to per-token scale when supporting cutlass
        qinput, x_scale = static_quant_fp8(
            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
        )
    elif _is_cuda:
        # default use per-token quantization if dynamic
        qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
    elif _is_hip and weight_scale.numel() == 1:
        # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
        # final solution should be: 1. add support to per-tensor activation scaling.
        # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1
        #    (the unfused fallback taken in _torch_fp8_scaled_mm)
        qinput, x_scale = ops.scaled_fp8_quant(
            input_2d,
            input_scale,
            use_per_token_if_dynamic=use_per_token_if_dynamic,
        )
    else:
        qinput, x_scale = per_token_group_quant_fp8(
            input_2d, group_size=input_2d.shape[1]
        )

    if cutlass_fp8_supported:
        try:
            output = _cutlass_fp8_scaled_mm(
                qinput, weight, x_scale, weight_scale, bias, input.dtype
            )
            return output.view(*output_shape)
        except (ImportError, NameError, AttributeError):
            # Deliberate best-effort: if the CUTLASS kernel or its module is
            # unavailable, fall through to the torch._scaled_mm path below.
            pass

    return _torch_fp8_scaled_mm(
        qinput,
        weight,
        x_scale,
        weight_scale,
        bias,
        input.dtype,
        input_2d.shape,
        output_shape,
        allow_rowwise=False,
    )