fp8_utils.py 23.2 KB
Newer Older
1
import os
2
3
from curses import flash
from typing import Callable, List, Optional, Tuple
HAI's avatar
HAI committed
4
5

import torch
HandH1998's avatar
HandH1998 committed
6

Yineng Zhang's avatar
Yineng Zhang committed
7
8
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

Lianmin Zheng's avatar
Lianmin Zheng committed
9
try:
10
    from vllm import _custom_ops as ops
Lianmin Zheng's avatar
Lianmin Zheng committed
11
12
13
14
15

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

16
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
HandH1998's avatar
HandH1998 committed
17
from sglang.srt.layers.quantization.fp8_kernel import (
18
19
20
    fp8_dtype,
    fp8_max,
    is_fp8_fnuz,
HandH1998's avatar
HandH1998 committed
21
    per_token_group_quant_fp8,
Lianmin Zheng's avatar
Lianmin Zheng committed
22
23
    scaled_fp8_quant,
    sglang_per_token_quant_fp8,
HandH1998's avatar
HandH1998 committed
24
    static_quant_fp8,
25
26
    w8a8_block_fp8_matmul_deepgemm,
    w8a8_block_fp8_matmul_triton,
HandH1998's avatar
HandH1998 committed
27
)
HandH1998's avatar
HandH1998 committed
28
29
30
31
from sglang.srt.utils import (
    get_bool_env_var,
    get_cuda_version,
    get_device_capability,
Lianmin Zheng's avatar
Lianmin Zheng committed
32
    is_cuda,
33
    is_flashinfer_available,
HandH1998's avatar
HandH1998 committed
34
35
36
    is_hip,
)

37
# Platform probes, evaluated a single time when the module is imported.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()

# Opt-in switch for the AITER blockscale GEMM on ROCm.
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")

# Backend GEMM kernels are only imported on the platform that uses them.
if _is_hip and use_aiter_moe:
    from aiter import gemm_a8w8_blockscale

if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

# Opt-in switch to route cutlass FP8 GEMMs through vLLM's kernel.
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

# Starting with PyTorch 2.5, torch._scaled_mm no longer treats the scale
# arguments as optional. _apply_fallback_scaled_mm lazily allocates a dummy
# one-element float32 "identity" scale on first use and caches it here.
TORCH_DEVICE_IDENTITY = None
HandH1998's avatar
HandH1998 committed
55

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

def _torch_version_tuple(version: Optional[str] = None) -> Tuple[int, int, int]:
    """Return ``(major, minor, patch)`` parsed from a PyTorch version string.

    Defaults to ``torch.__version__``. The local build tag after ``+`` and any
    pre-release/dev suffix are ignored, so source builds such as
    ``"2.7.0a0+git1234"`` parse as ``(2, 7, 0)`` instead of failing to parse.
    Missing components count as 0, and ``(0, 0, 0)`` is returned when the
    string does not start with a number.
    """
    import re

    if version is None:
        version = str(torch.__version__)
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version.split("+")[0])
    if match is None:
        return (0, 0, 0)
    major, minor, patch = (int(g) if g is not None else 0 for g in match.groups())
    return (major, minor, patch)


def use_rowwise_torch_scaled_mm():
    """Return True if torch._scaled_mm rowwise fp8 scaling should be used.

    Only enabled on ROCm (HIP) devices with capability >= (9, 4) running
    PyTorch >= 2.7.0. The condition is meant to be evaluated once, because
    the probes are time consuming.
    """
    if not _is_hip:
        return False
    # Rowwise scaling in torch._scaled_mm (pytorch PR #144432, hipBLASLt,
    # ROCm 6.3) is only available from torch 2.7 onwards.
    return get_device_capability() >= (9, 4) and _torch_version_tuple() >= (2, 7, 0)


# Evaluated once at import time; consulted by the rowwise torch._scaled_mm
# branch of apply_fp8_linear.
USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
73

HandH1998's avatar
HandH1998 committed
74
75
76
77
78
79
80
81
82
83
84
85

def cutlass_fp8_supported():
    """Report whether the CUTLASS FP8 scaled-mm path can run on this device.

    Compute capability 9.x and newer needs CUDA >= 12.0, compute capability
    8.9 needs CUDA >= 12.4, and everything else (including non-CUDA
    platforms) is unsupported.
    """
    if not _is_cuda:
        return False
    major, minor = get_device_capability()
    toolkit = get_cuda_version()
    if major >= 9:
        minimum_toolkit = (12, 0)
    elif (major, minor) == (8, 9):
        minimum_toolkit = (12, 4)
    else:
        return False
    return toolkit >= minimum_toolkit

HAI's avatar
HAI committed
86

87
88
89
90
91
92
def _cuda_version_tuple(version: Optional[str]) -> Tuple[int, ...]:
    """Parse a CUDA toolkit version such as ``"12.8"`` into ``(12, 8)``.

    Returns ``()``, which compares below any real version, when ``version``
    is None/empty or is not made of dot-separated integers.
    """
    if not version:
        return ()
    try:
        return tuple(int(part) for part in version.split("."))
    except ValueError:
        return ()


def is_sm100_supported(device=None) -> bool:
    """Return True if ``device`` has compute capability 10.x and torch was
    built against CUDA >= 12.8.

    The CUDA version is compared numerically. A raw string comparison would be
    lexicographic (``"12.10" >= "12.8"`` is False). Builds without CUDA
    (``torch.version.cuda is None``) and hosts where
    ``torch.cuda.is_available()`` is False report False instead of raising.
    """
    if not torch.cuda.is_available():
        return False
    if torch.cuda.get_device_capability(device)[0] != 10:
        return False
    return _cuda_version_tuple(torch.version.cuda) >= (12, 8)


HAI's avatar
HAI committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bits pattern 10000000(-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz. So here we set it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
HandH1998's avatar
HandH1998 committed
115
116


117
def cutlass_block_fp8_supported() -> bool:
    """Whether the CUTLASS blockwise FP8 GEMM path may be used.

    It is opt-in through SGLANG_SUPPORT_CUTLASS_BLOCK_FP8. It also requires a
    CUDA platform whose torch build reports CUDA >= 12.0, running on a device
    with compute capability 9.0 or newer.
    """
    if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8") or not _is_cuda:
        return False
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor
    toolkit = tuple(int(piece) for piece in torch.version.cuda.split("."))
    return toolkit >= (12, 0) and sm >= 90


# Evaluated once at import time.
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
130
131
132
133
134
135
136
# Opt-in FlashInfer groupwise FP8 GEMM. It needs three things: the
# SGLANG_ENABLE_FLASHINFER_GEMM env flag, is_sm100_supported(), and
# is_flashinfer_available(). The env flag is checked first, so the device
# probe only runs when the feature was requested.
ENABLE_FLASHINFER_GEMM = (
    get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
    and is_sm100_supported()
    and is_flashinfer_available()
)
if ENABLE_FLASHINFER_GEMM:
    # Only imported when enabled; used by flashinfer_gemm_w8a8_block_fp8_linear.
    from flashinfer.gemm import gemm_fp8_nt_groupwise
137
138


139
140
141
142
143
144
145
146
147
148
149
150
151
152
def dispatch_w8a8_block_fp8_linear() -> Callable:
    """Select the block-quantized W8A8 FP8 linear implementation.

    Candidates are tried in priority order: FlashInfer, CUTLASS blockwise,
    AITER (ROCm with SGLANG_AITER_MOE), and DeepGEMM JIT. Triton is the final
    fallback.
    """
    if ENABLE_FLASHINFER_GEMM:
        return flashinfer_gemm_w8a8_block_fp8_linear
    if CUTLASS_BLOCK_FP8_SUPPORTED:
        return cutlass_w8a8_block_fp8_linear_with_fallback
    if _is_hip and use_aiter_moe:
        return aiter_w8a8_block_fp8_linear
    if _ENABLE_JIT_DEEPGEMM:
        return deepgemm_w8a8_block_fp8_linear_with_fallback
    return triton_w8a8_block_fp8_linear


def flashinfer_gemm_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """W8A8 block-FP8 linear layer backed by FlashInfer's gemm_fp8_nt_groupwise.

    Activations are quantized dynamically with sglang_per_token_group_quant_fp8
    (group size ``block_size[1]``, row-major scales), so ``input_scale`` must
    be None. The result has shape ``(*input.shape[:-1], weight.shape[0])`` and
    ``input``'s dtype.
    """
    assert input_scale is None

    out_dtype = input.dtype
    flat_input = input.view(-1, input.shape[-1])
    q_input, x_scale = sglang_per_token_group_quant_fp8(
        flat_input, block_size[1], column_major_scales=False
    )

    result = gemm_fp8_nt_groupwise(
        q_input,
        weight,
        x_scale,
        weight_scale,
        scale_major_mode="K",
        out_dtype=out_dtype,
    )
    if bias is not None:
        result += bias

    return result.to(dtype=out_dtype).view(*input.shape[:-1], weight.shape[0])


def cutlass_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """W8A8 block-FP8 linear layer via sgl-kernel's fp8_blockwise_scaled_mm.

    Falls back to triton_w8a8_block_fp8_linear unless both weight dimensions
    are multiples of 128. Activations are quantized dynamically with
    column-major scales, and ``weight`` and ``weight_scale`` are passed to the
    kernel transposed.
    """
    assert input_scale is None

    # TODO: add more robust shape check here
    rows_out, cols_in = weight.shape[0], weight.shape[1]
    if rows_out % 128 != 0 or cols_in % 128 != 0:
        # Shapes outside the check above go to Triton instead.
        return triton_w8a8_block_fp8_linear(
            input, weight, block_size, weight_scale, input_scale, bias
        )

    out_dtype = input.dtype
    flat_input = input.view(-1, input.shape[-1])
    q_input, x_scale = per_token_group_quant_fp8(
        flat_input, block_size[1], column_major_scales=True
    )
    result = fp8_blockwise_scaled_mm(
        q_input, weight.T, x_scale, weight_scale.T, out_dtype=out_dtype
    )
    if bias is not None:
        result += bias
    return result.to(dtype=out_dtype).view(*input.shape[:-1], weight.shape[0])


def deepgemm_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """W8A8 block-FP8 linear layer via the DeepGEMM JIT kernel.

    DeepGEMM is only used when ``input`` is bfloat16 and both weight
    dimensions are multiples of 128. Otherwise the call is delegated to
    triton_w8a8_block_fp8_linear. Activations are quantized with column-major,
    TMA-aligned scales.
    """
    assert input_scale is None

    out_dtype = input.dtype
    # TODO: add more robust shape check here
    shapes_ok = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
    if out_dtype != torch.bfloat16 or not shapes_ok:
        # fall back to triton
        return triton_w8a8_block_fp8_linear(
            input, weight, block_size, weight_scale, input_scale, bias
        )

    flat_input = input.view(-1, input.shape[-1])
    q_input, x_scale = sglang_per_token_group_quant_fp8(
        flat_input,
        block_size[1],
        column_major_scales=True,
        scale_tma_aligned=True,
    )
    result = w8a8_block_fp8_matmul_deepgemm(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype
    )
    if bias is not None:
        result += bias
    return result.to(dtype=out_dtype).view(*input.shape[:-1], weight.shape[0])


def aiter_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """W8A8 block-FP8 linear layer via AITER's gemm_a8w8_blockscale (ROCm).

    Activations are quantized dynamically with row-major scales, so
    ``input_scale`` must be None.
    """
    assert input_scale is None

    out_dtype = input.dtype
    flat_input = input.view(-1, input.shape[-1])
    q_input, x_scale = per_token_group_quant_fp8(
        flat_input, block_size[1], column_major_scales=False
    )

    # gemm_a8w8_blockscale is handed `result` as its output buffer; its own
    # return value is not used.
    result = torch.zeros(
        q_input.shape[0],
        weight.shape[0],
        dtype=out_dtype,
        device=q_input.device,
    )
    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, result)

    if bias is not None:
        result += bias
    return result.to(dtype=out_dtype).view(*input.shape[:-1], weight.shape[0])


def triton_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """W8A8 block-FP8 linear layer via the Triton blockwise matmul.

    This is the generic fallback used by the other block-FP8 backends.
    Activations are quantized dynamically with row-major scales, so
    ``input_scale`` must be None.
    """
    assert input_scale is None

    out_dtype = input.dtype
    flat_input = input.view(-1, input.shape[-1])
    q_input, x_scale = per_token_group_quant_fp8(
        flat_input, block_size[1], column_major_scales=False
    )
    result = w8a8_block_fp8_matmul_triton(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype
    )
    if bias is not None:
        result += bias
    return result.to(dtype=out_dtype).view(*input.shape[:-1], weight.shape[0])
HandH1998's avatar
HandH1998 committed
305
306
307


def input_to_float8(
    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to float8 using a single tensor-wide scale.

    Returns ``(quantized, dequant_scale)``. ``dequant_scale`` is the
    reciprocal of the multiplier applied to ``x``. On fnuz platforms the
    module's ``fp8_dtype``/``fp8_max`` are used regardless of ``dtype``.
    """
    lo, hi = x.aminmax()
    # Largest magnitude in the tensor, floored so the division below is safe.
    amax = torch.maximum(lo.abs(), hi.abs()).float().clamp(min=1e-12)

    if _is_fp8_fnuz:
        dtype, fp_max = fp8_dtype, fp8_max
    else:
        fp_max = torch.finfo(dtype).max

    scale = fp_max / amax
    saturated = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
    return saturated.to(dtype).contiguous(), scale.float().reciprocal()


def block_quant_to_tensor_quant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert block-wise quantization into tensor-wise quantization.

    ``x_q_block`` is the 2D block-quantized tensor and ``x_s`` holds one scale
    per ``(block_size[0], block_size[1])`` tile. Each tile is dequantized into
    a float32 copy, then the whole tensor is requantized with one scale.
    Returns ``(tensor_quantized, tensor_scale)``. Only float8 is supported for
    now.

    NOTE(review): if ``x_q_block`` is already float32, ``.to`` returns the same
    tensor, so the tiles are rescaled in the caller's tensor.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    dequant = x_q_block.to(torch.float32)
    for j in range(n_tiles):
        rows = slice(j * block_n, min((j + 1) * block_n, n))
        for i in range(k_tiles):
            cols = slice(i * block_k, min((i + 1) * block_k, k))
            dequant[rows, cols] = dequant[rows, cols] * x_s[j][i]

    if _is_cuda:
        x_q_tensor, scale = scaled_fp8_quant(dequant)
    else:
        x_q_tensor, scale = input_to_float8(dequant, dtype=x_q_block.dtype)
    return x_q_tensor, scale


369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def block_quant_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize a block-wise quantized 2D tensor into ``dtype``.

    Tile ``(j, i)`` of size ``block_size[0] x block_size[1]`` (edge tiles may
    be smaller) is multiplied in float32 by ``x_s[j][i]`` and then written
    into a fresh tensor of ``dtype``.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    out = torch.empty_like(x_q_block, dtype=dtype)
    for j, row_start in enumerate(range(0, n, block_n)):
        rows = slice(row_start, min(row_start + block_n, n))
        for i, col_start in enumerate(range(0, k, block_k)):
            cols = slice(col_start, min(col_start + block_k, k))
            out[rows, cols] = x_q_block[rows, cols].to(torch.float32) * x_s[j][i]
    return out


404
405
406
407
408
def channel_quant_to_tensor_quant(
    x_q_channel: torch.Tensor,
    x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize a per-channel quantized tensor with a single tensor-wide scale.

    The input is dequantized to float32 by multiplying with ``x_s``
    (broadcast), then requantized. CUDA uses scaled_fp8_quant; other platforms
    use input_to_float8 targeting ``x_q_channel``'s dtype.
    """
    dequantized = x_q_channel.to(torch.float32) * x_s
    if _is_cuda:
        x_q_tensor, scale = scaled_fp8_quant(dequantized)
    else:
        x_q_tensor, scale = input_to_float8(dequantized, dtype=x_q_channel.dtype)
    return x_q_tensor, scale


417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
def _process_scaled_mm_output(output, input_2d_shape, output_shape):
    """Normalize a torch._scaled_mm result to ``output_shape``.

    A plain 2-tuple result (presumably ``(out, amax)`` from some torch
    versions) is unwrapped to its first element. Rows past
    ``input_2d_shape[0]``, e.g. token padding added during quantization, are
    dropped before reshaping.
    """
    is_pair = type(output) is tuple and len(output) == 2
    result = output[0] if is_pair else output
    num_rows = input_2d_shape[0]
    return result.narrow(0, 0, num_rows).view(*output_shape)


def _apply_fallback_scaled_mm(
    qinput,
    weight,
    x_scale,
    weight_scale,
    input_2d_shape,
    output_shape,
    bias,
    input_dtype,
):
    """Unfused FP8 GEMM: C = (X @ W) * s_x * s_w^T + bias.

    torch._scaled_mm runs with unit scales and float32 output. The activation
    and weight scales are then applied elementwise, because this path is used
    when the scales are not both per-tensor. The result is cast to
    ``input_dtype``.
    """
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        # Cached on the device of the first weight seen.
        # NOTE(review): later calls reuse it, so assumes one device per process.
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
    unit_scale = TORCH_DEVICE_IDENTITY

    raw = torch._scaled_mm(
        qinput,
        weight,
        scale_a=unit_scale,
        scale_b=unit_scale,
        out_dtype=torch.float32,
    )
    raw = _process_scaled_mm_output(raw, input_2d_shape, output_shape)
    # Match the activation scales to the un-padded rows kept above.
    row_scales = torch.narrow(x_scale, 0, 0, input_2d_shape[0])

    result = raw * row_scales * weight_scale.t()
    if bias is not None:
        result = result + bias
    return result.to(dtype=input_dtype)


HandH1998's avatar
HandH1998 committed
454
455
456
457
458
459
460
def _fp8_cutlass_gemm(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fused FP8 GEMM + dequant on the CUTLASS path.

    Uses vLLM's cutlass_scaled_mm when vLLM is importable and
    USE_VLLM_CUTLASS_W8A8_FP8_KERNEL is set. Otherwise it uses sgl-kernel's
    fp8_scaled_mm, which only accepts a per-channel weight scale.
    """
    if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
        # Fall back to vllm cutlass w8a8 fp8 kernel
        return ops.cutlass_scaled_mm(
            qinput,
            weight,
            out_dtype=out_dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
        )
    assert (
        weight_scale.numel() == weight.shape[1]
    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
    return fp8_scaled_mm(
        qinput,
        weight,
        x_scale,
        weight_scale,
        out_dtype=out_dtype,
        bias=bias,
    )


def _fp8_torch_scaled_mm(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    input_2d_shape,
    output_shape,
    bias: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    allow_rowwise,
) -> torch.Tensor:
    """FP8 GEMM through torch._scaled_mm, or the unfused fallback.

    torch._scaled_mm is used fused in two cases:
    - both the weight and activation scales are per-tensor;
    - ``allow_rowwise`` and USE_ROWWISE_TORCH_SCALED_MM are both truthy and
      neither scale is per-tensor (the weight scale is then passed transposed).

    Any other combination goes through _apply_fallback_scaled_mm. Symmetric
    quantized GEMM computes C = (s_x * X)(s_w * W) + bias =
    s_w * s_x * (X * W) + bias. Since scaled_mm cannot take a vector s_w, the
    fallback applies the scales after an unscaled GEMM.
    """
    per_tensor_weights = weight_scale.numel() == 1
    per_tensor_activations = x_scale.numel() == 1

    if per_tensor_weights and per_tensor_activations:
        weight_scale_arg = weight_scale
    elif (
        allow_rowwise
        and not per_tensor_weights
        and not per_tensor_activations
        and USE_ROWWISE_TORCH_SCALED_MM
    ):
        # For now validated on ROCm platform
        # fp8 rowwise scaling in torch._scaled_mm is introduced in
        # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
        # and ROCm 6.3, which only exists in torch 2.7 and above.
        # For CUDA platform please validate if the
        # torch._scaled_mm support rowwise scaled GEMM
        weight_scale_arg = weight_scale.t()
    else:
        return _apply_fallback_scaled_mm(
            qinput,
            weight,
            x_scale,
            weight_scale,
            input_2d_shape,
            output_shape,
            bias,
            out_dtype,
        )

    # Fused GEMM_DQ
    output = torch._scaled_mm(
        qinput,
        weight,
        out_dtype=out_dtype,
        scale_a=x_scale,
        scale_b=weight_scale_arg,
        bias=bias,
    )
    return _process_scaled_mm_output(output, input_2d_shape, output_shape)


def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
    use_per_token_if_dynamic: bool = False,
    pad_output: Optional[bool] = None,
    compressed_tensor_quant: bool = False,
) -> torch.Tensor:
    """Quantize ``input`` to FP8 and multiply it by an FP8 ``weight``.

    Args:
        input: activations. They are viewed as 2D ``(-1, input.shape[-1])``.
        weight: FP8 weight. The output's last dimension is ``weight.shape[1]``.
        weight_scale: weight scale. Per-tensor when ``numel() == 1``;
            per-channel when ``numel() == weight.shape[1]``.
        input_scale: static activation scale, or None for dynamic
            quantization. Outside the compressed-tensor path it must be
            per-tensor.
        input_scale_ub: accepted for interface compatibility; not used here.
        bias: optional bias added to the output.
        cutlass_fp8_supported: whether to try the CUTLASS GEMM path. The
            default is probed once, when this function is defined.
        use_per_token_if_dynamic: request per-token dynamic activation scales.
        pad_output: pad the token dimension (to 17) before torch._scaled_mm
            on the compressed-tensor path. Defaults to True unless
            SGLANG_ENABLE_TORCH_COMPILE is set.
        compressed_tensor_quant: use the compressed-tensors quantization flow.

    Returns:
        Tensor of shape ``(*input.shape[:-1], weight.shape[1])`` in
        ``input``'s dtype.
    """
    # Padding the token dimension makes torch._scaled_mm faster for batch
    # sizes > 16 (this could change in the future). It is skipped under
    # torch.compile, which breaks with dynamic shapes.
    if pad_output is None:
        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
    output_padding = 17 if pad_output else None

    # The FP8 kernels below operate on a 2D view of the input.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]
    out_dtype = input.dtype

    if compressed_tensor_quant:
        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token
        # A; sgl-kernel fp8_scaled_mm only supports per-channel W for now.
        use_cutlass = cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]
        if use_cutlass:
            qinput, x_scale = scaled_fp8_quant(
                input_2d,
                input_scale,
                use_per_token_if_dynamic=use_per_token_if_dynamic,
            )
            return _fp8_cutlass_gemm(
                qinput, weight, x_scale, weight_scale, out_dtype, bias
            ).view(*output_shape)

        # torch._scaled_mm path (with optional token padding).
        quantize = scaled_fp8_quant if _is_cuda else ops.scaled_fp8_quant
        qinput, x_scale = quantize(
            input_2d,
            input_scale,
            num_token_padding=output_padding,
            use_per_token_if_dynamic=use_per_token_if_dynamic,
        )
        return _fp8_torch_scaled_mm(
            qinput,
            weight,
            x_scale,
            weight_scale,
            input_2d.shape,
            output_shape,
            bias,
            out_dtype,
            allow_rowwise=use_per_token_if_dynamic,
        )

    # The cutlass w8a8 fp8 sgl-kernel only supports per-token activation scales.
    if input_scale is not None:
        assert input_scale.numel() == 1
        # Broadcast the per-tensor scale to per-token when cutlass will be used.
        qinput, x_scale = static_quant_fp8(
            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
        )
    elif _is_cuda:
        # Dynamic quantization defaults to per-token.
        qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
    elif _is_hip and weight_scale.numel() == 1:
        # TODO(kkhuang): temporarily enforce per-tensor activation scaling if
        # weight is per-tensor scaling. Final solution should be:
        # 1. add support to per-tensor activation scaling.
        # 2. solve the torch.compile error from weight_scale.numel() == 1 and
        #    x_scale.numel() > 1 in the torch._scaled_mm fallback path.
        qinput, x_scale = ops.scaled_fp8_quant(
            input_2d,
            input_scale,
            use_per_token_if_dynamic=use_per_token_if_dynamic,
        )
    else:
        qinput, x_scale = per_token_group_quant_fp8(
            input_2d, group_size=input_2d.shape[1]
        )

    if cutlass_fp8_supported:
        try:
            return _fp8_cutlass_gemm(
                qinput, weight, x_scale, weight_scale, out_dtype, bias
            ).view(*output_shape)
        except (ImportError, NameError, AttributeError):
            # Best effort: if a cutlass kernel is unavailable, fall through
            # to torch._scaled_mm below.
            pass

    # No rowwise torch._scaled_mm on this path: per-tensor scales run fused,
    # everything else uses the unfused fallback.
    return _fp8_torch_scaled_mm(
        qinput,
        weight,
        x_scale,
        weight_scale,
        input_2d.shape,
        output_shape,
        bias,
        out_dtype,
        allow_rowwise=False,
    )