fp8_utils.py 10.6 KB
Newer Older
HandH1998's avatar
HandH1998 committed
1
from typing import List, Optional, Tuple
HAI's avatar
HAI committed
2
3

import torch
HandH1998's avatar
HandH1998 committed
4
5
6

from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
HandH1998's avatar
HandH1998 committed
7
    static_quant_fp8,
HandH1998's avatar
HandH1998 committed
8
9
    w8a8_block_fp8_matmul,
)
HandH1998's avatar
HandH1998 committed
10
11
12
13
from sglang.srt.utils import (
    get_bool_env_var,
    get_cuda_version,
    get_device_capability,
Lianmin Zheng's avatar
Lianmin Zheng committed
14
    is_cuda,
HandH1998's avatar
HandH1998 committed
15
16
17
    is_hip,
)

Lianmin Zheng's avatar
Lianmin Zheng committed
18
# When set, apply_fp8_linear uses vLLM's cutlass_scaled_mm instead of
# sgl-kernel's fp8_scaled_mm for the per-channel CUTLASS path.
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

# Platform flag, evaluated once at import time.
is_hip_ = is_hip()

# ROCm only: the aiter block-scale GEMM is imported only when CK_MOE is set,
# which is also the condition under which apply_w8a8_block_fp8_linear calls it.
if is_hip_ and get_bool_env_var("CK_MOE"):
    from aiter import gemm_a8w8_blockscale

_is_cuda = is_cuda()

# CUDA only: pull in the sgl-kernel GEMMs and the per-token quantizer, plus
# whichever per-channel FP8 GEMM backend was selected above.
if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm

    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8

    if use_vllm_cutlass_w8a8_fp8_kernel:
        from vllm import _custom_ops as ops
    else:
        from sgl_kernel import fp8_scaled_mm


def cutlass_fp8_supported():
    """Report whether the CUTLASS per-tensor/per-channel FP8 GEMM is usable.

    Returns:
        True on CUDA when the device has compute capability >= 9.0 and the
        CUDA version is >= 12.0, or compute capability exactly 8.9 and CUDA
        >= 12.4. False for any other device and on non-CUDA platforms.
    """
    if not _is_cuda:
        return False

    major, minor = get_device_capability()
    cuda_version = get_cuda_version()

    # Pick the minimum CUDA toolkit required for this architecture.
    if major >= 9:
        required_cuda = (12, 0)
    elif (major, minor) == (8, 9):
        required_cuda = (12, 4)
    else:
        return False
    return cuda_version >= required_cuda

HAI's avatar
HAI committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bits pattern 10000000(-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz. So here we set it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
HandH1998's avatar
HandH1998 committed
70
71


72
def cutlass_block_fp8_supported() -> bool:
    """Report whether the CUTLASS block-wise FP8 GEMM path should be used.

    The path is opt-in. It is enabled only when the
    ``SUPPORT_CUTLASS_BLOCK_FP8`` environment variable is set to a true value,
    and then only on CUDA with CUDA >= 12.0 and a device whose compute
    capability is >= 9.0.

    Returns:
        True if the CUTLASS block FP8 kernel may be used, False otherwise.
    """
    # Opt-in switch. A truthy SUPPORT_CUTLASS_BLOCK_FP8 must *enable* the
    # CUTLASS path. The previous check was inverted: setting the variable
    # disabled the path, and leaving it unset enabled it.
    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
        return False
    if _is_cuda:
        major, minor = torch.cuda.get_device_capability()
        sm_version = major * 10 + minor
        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
        if cuda_version >= (12, 0) and sm_version >= 90:
            return True
    return False


# Evaluated once at import time; apply_w8a8_block_fp8_linear reads this flag.
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()


HandH1998's avatar
HandH1998 committed
87
88
89
90
91
92
93
94
95
96
97
98
def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Linear layer with block-quantized FP8 weights and dynamically quantized input.

    The input is quantized per token in groups of ``block_size[1]`` elements.
    The GEMM then runs on one of three backends:

    * the CUTLASS blockwise kernel, when ``CUTLASS_BLOCK_FP8_SUPPORTED`` is
      set and both weight dims are multiples of 128;
    * the aiter ``gemm_a8w8_blockscale`` kernel on ROCm when ``CK_MOE`` is set;
    * ``w8a8_block_fp8_matmul`` otherwise.

    Args:
        input: Activations; all leading dims are flattened into rows.
        weight: Quantized weight. The output's last dim is ``weight.shape[0]``.
        block_size: Quantization block size; ``block_size[1]`` is used as the
            activation quantization group size.
        weight_scale: Block-wise weight scales.
        input_scale: Must be None (only dynamic activation quantization is
            supported here).
        bias: Optional bias added to the output.

    Returns:
        Tensor of dtype ``input.dtype`` with shape
        ``[*input.shape[:-1], weight.shape[0]]``.
    """
    assert input_scale is None

    # The fp8 kernels operate on 2D matrices.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    group_size = block_size[1]

    # TODO: add more robust shape check here
    cutlass_shape_ok = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0

    if CUTLASS_BLOCK_FP8_SUPPORTED and cutlass_shape_ok:
        # The CUTLASS kernel receives column-major activation scales and the
        # transposed weight / weight-scale tensors.
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, group_size, column_major_scales=True
        )
        output = fp8_blockwise_scaled_mm(
            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
        )
    else:
        # Both remaining backends take row-major activation scales.
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, group_size, column_major_scales=False
        )
        if is_hip_ and get_bool_env_var("CK_MOE"):
            # The aiter kernel writes into a preallocated output buffer.
            output = torch.zeros(
                [q_input.shape[0], weight.shape[0]],
                dtype=input.dtype,
                device=q_input.device,
            )
            gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
        else:
            output = w8a8_block_fp8_matmul(
                q_input,
                weight,
                x_scale,
                weight_scale,
                block_size,
                output_dtype=input.dtype,
            )

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)


def input_to_float8(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to ``dtype`` using a single tensor-wide scale.

    Returns:
        ``(x_q, inv_scale)``. ``x_q`` is ``x`` scaled so that its absolute
        maximum maps to the fp8 maximum, saturated and cast to ``dtype``
        (contiguous). ``inv_scale`` is the float32 reciprocal of the
        quantization scale, i.e. the dequantization scale.
    """
    finfo = torch.finfo(dtype)
    # On ROCm the saturation bound is fixed at 224.0 rather than finfo.max.
    fp8_max = 224.0 if is_hip_ else finfo.max

    lo, hi = x.aminmax()
    # Clamp below to avoid dividing by zero for an all-zero tensor.
    abs_max = torch.maximum(lo.abs(), hi.abs()).clamp(min=1e-12)

    quant_scale = fp8_max / abs_max
    saturated = torch.clamp(x * quant_scale, min=-fp8_max, max=fp8_max)
    return saturated.to(dtype).contiguous(), quant_scale.float().reciprocal()


def block_quant_to_tensor_quant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """This function converts block-wise quantization to tensor-wise quantization.
    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
    and the block size.
    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
    Note only float8 is supported for now.

    Args:
        x_q_block: 2D block-quantized tensor of shape ``(n, k)``.
        x_s: One scale per tile, shape
            ``(ceil(n / block_n), ceil(k / block_k))``.
        block_size: ``[block_n, block_k]``, the tile size along dims 0 and 1.

    Returns:
        ``(x_q_tensor, scale)`` as returned by ``input_to_float8`` for the
        dequantized float32 tensor, re-quantized to ``x_q_block.dtype``.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0], f"expected {n_tiles} row tiles, got {x_s.shape[0]}"
    assert k_tiles == x_s.shape[1], f"expected {k_tiles} col tiles, got {x_s.shape[1]}"

    x_dq_block = x_q_block.to(torch.float32)

    # Dequantize every tile in a single broadcasted in-place multiply instead
    # of a Python double loop with one small multiply per tile.
    # The small scale grid is cast to the working dtype/device first. This
    # matches the original per-tile multiply, where each 0-dim scale was
    # applied in float32.
    scales = x_s.to(device=x_dq_block.device, dtype=x_dq_block.dtype)
    # Expand each scale over its block_n x block_k tile. Cropping to
    # [:n, :k] handles partial tiles on the bottom/right edges.
    scale_map = scales.repeat_interleave(block_n, dim=0)[:n].repeat_interleave(
        block_k, dim=1
    )[:, :k]
    x_dq_block.mul_(scale_map)
    # Free the full-size scale map before input_to_float8 allocates its own
    # full-size temporaries.
    del scale_map

    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
    return x_q_tensor, scale


HandH1998's avatar
HandH1998 committed
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def _unwrap_scaled_mm_output(result):
    """Return the output tensor of ``torch._scaled_mm``.

    ``torch._scaled_mm`` returns a 2-tuple for torch < 2.5 and a single value
    for torch >= 2.5; both forms are normalized to the output tensor.
    """
    if type(result) is tuple and len(result) == 2:
        return result[0]
    return result


def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
    use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
    """FP8 linear layer with per-tensor or per-channel weight scales.

    Activations are quantized statically when ``input_scale`` (a single
    element) is given, and per token otherwise. The GEMM then runs on one of:

    * a CUTLASS kernel (vLLM's or sgl-kernel's) when ``cutlass_fp8_supported``;
    * a fused ``torch._scaled_mm`` when both weight and activation scales are
      per-tensor;
    * an unscaled ``torch._scaled_mm`` in float32 followed by explicit
      dequantization otherwise.

    ``input_scale_ub`` and ``use_per_token_if_dynamic`` are accepted for
    interface compatibility but are not used by this function.

    Returns:
        Tensor with shape ``[*input.shape[:-1], weight.shape[1]]``.
    """
    # Declared up front; only the channelwise fallback rebinds it.
    # NOTE(review): assumes TORCH_DEVICE_IDENTITY is a module-level tensor
    # defined elsewhere in this file — it is not visible in this view.
    global TORCH_DEVICE_IDENTITY

    # The fp8 kernels operate on 2D matrices.
    input_2d = input.view(-1, input.shape[-1])
    num_tokens = input_2d.shape[0]
    output_shape = [*input.shape[:-1], weight.shape[1]]

    # --- Activation quantization -------------------------------------------
    if input_scale is None:
        # Dynamic quantization defaults to per-token scales.
        if _is_cuda:
            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
        else:
            qinput, x_scale = per_token_group_quant_fp8(
                input_2d, group_size=input_2d.shape[1]
            )
    else:
        assert input_scale.numel() == 1
        # The cutlass w8a8 fp8 sgl-kernel only takes per-token scales, so the
        # per-tensor scale is broadcast per token when CUTLASS will be used.
        qinput, x_scale = static_quant_fp8(
            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
        )

    # --- CUTLASS path ---------------------------------------------------------
    if cutlass_fp8_supported:
        if use_vllm_cutlass_w8a8_fp8_kernel:
            output = ops.cutlass_scaled_mm(
                qinput,
                weight,
                out_dtype=input.dtype,
                scale_a=x_scale,
                scale_b=weight_scale,
                bias=bias,
            )
        else:
            assert (
                weight_scale.numel() == weight.shape[1]
            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
            output = fp8_scaled_mm(
                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
            )
        return output.view(*output_shape)

    # --- torch._scaled_mm, fused: per-tensor weights AND activations --------
    if weight_scale.numel() == 1 and x_scale.numel() == 1:
        fused = _unwrap_scaled_mm_output(
            torch._scaled_mm(
                qinput,
                weight,
                out_dtype=input.dtype,
                scale_a=x_scale,
                scale_b=weight_scale,
                bias=bias,
            )
        )
        return torch.narrow(fused, 0, 0, num_tokens).view(*output_shape)

    # --- torch._scaled_mm, unfused: per-channel and/or per-token scales -----
    # The quantized GEMM computes C = s_w * s_x * (X * W) + bias. _scaled_mm
    # cannot take a vector s_w here, so (X * W) is computed with identity
    # scales in float32 and the scales and bias are applied afterwards.

    # Keep the dummy identity scale on the same device as the weight.
    if TORCH_DEVICE_IDENTITY.device != weight.device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

    raw = _unwrap_scaled_mm_output(
        torch._scaled_mm(
            qinput,
            weight,
            scale_a=TORCH_DEVICE_IDENTITY,
            scale_b=TORCH_DEVICE_IDENTITY,
            out_dtype=torch.float32,
        )
    )
    # Unpad (undo num_token_padding).
    raw = torch.narrow(raw, 0, 0, num_tokens)
    x_scale = torch.narrow(x_scale, 0, 0, num_tokens)

    # Dequantize: C = s_w * s_x * (X * W) + bias
    output = raw * x_scale * weight_scale.t()
    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)