quant_utils.py 5.41 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Optional, Union
4
5
6

import torch

7
from vllm.platforms import current_platform
8
9
10
11

# Largest FP8 magnitude the reference quantizers use on ROCm.
# Using the default value (240.0) from PyTorch causes accuracy issues
# on dynamic quantization models, so 224.0 is used on ROCm instead.
ROCM_FP8_MAX = 224.0
12
# FP8 dtype reported by the current platform; the reference quantizers in
# this module use it as their FP8 target dtype.
FP8_DTYPE = current_platform.fp8_dtype()
13

14
15
16
17
18

def as_float32_tensor(
        x: Union[float, torch.Tensor],
        device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """Return ``x`` as a float32 tensor placed on ``device``.

    Thin wrapper around ``torch.as_tensor`` that pins the dtype to
    ``torch.float32``.

    Args:
        x: A Python scalar or an existing tensor.
        device: Target device. Defaults to ``'cuda'``, which is what every
            caller got before this parameter existed; pass ``'cpu'`` (or any
            other device) to run the reference code elsewhere.

    Returns:
        A float32 tensor on ``device`` holding the values of ``x``.
    """
    return torch.as_tensor(x, dtype=torch.float32, device=device)

def ref_dynamic_per_token_quant(x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                scale_ub: Optional[torch.Tensor] = None) \
        -> tuple[torch.Tensor, torch.Tensor]:
    """Reference dynamic per-token quantization to int8 or FP8.

    A scale is derived for each token from the absolute maximum taken over
    the last dimension of ``x``; ``x`` is then scaled, clamped to the
    representable range of ``quant_dtype`` and cast to it.

    Args:
        x: Tensor to quantize. The per-token maximum is indexed with
            ``[:, None]`` afterwards, so a 2-D input is expected
            (TODO confirm against callers).
        quant_dtype: Either ``torch.int8`` or ``FP8_DTYPE``.
        scale_ub: Optional upper bound applied to the per-token absolute
            maximum before the scale is computed. Only valid for FP8.

    Returns:
        A ``(quantized, scales)`` tuple: the quantized tensor in
        ``quant_dtype`` and the float32 per-token scales with a trailing
        singleton dimension.

    Raises:
        AssertionError: If ``quant_dtype`` is unsupported, or ``scale_ub`` is
            given for a non-FP8 ``quant_dtype``.
    """
    assert quant_dtype in [torch.int8, FP8_DTYPE]
    if scale_ub is not None:
        assert quant_dtype == FP8_DTYPE

    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
            else torch.finfo(quant_dtype)
    # The reduced ROCm range is an FP8 property. int8 must keep the limits
    # reported by torch.iinfo on every platform; otherwise values would be
    # scaled and clamped to +/-224 and overflow when cast to int8.
    use_rocm_fp8_range = quant_dtype == FP8_DTYPE \
        and current_platform.is_rocm()
    qtype_traits_max = ROCM_FP8_MAX if use_rocm_fp8_range \
        else qtype_traits.max
    qtype_traits_min = -ROCM_FP8_MAX if use_rocm_fp8_range \
        else qtype_traits.min
    qtype_max = as_float32_tensor(qtype_traits_max)
    s_1 = as_float32_tensor(1.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    # Compute scales
    x_token_max, _ = x.abs().max(dim=-1)
    x_token_max = as_float32_tensor(x_token_max)
    if scale_ub is not None:
        x_token_max = x_token_max.clamp(max=scale_ub)
    scales = (x_token_max / qtype_max)[:, None]

    # Quant
    if quant_dtype == torch.int8:
        # int8 multiplies by the reciprocal scale and rounds before clamping.
        iscales = as_float32_tensor(s_1 / scales)
        torch_out = as_float32_tensor(x) * iscales
        torch_out = torch_out.round()
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)
    else:
        assert quant_dtype == FP8_DTYPE
        # FP8 floors the scale at 1 / (qtype_max * 512) and divides by the
        # scale directly (no reciprocal, no explicit rounding).
        s_512 = as_float32_tensor(512.0)
        min_scaling_factor = s_1 / (qtype_max * s_512)
        scales = scales.clamp(min=min_scaling_factor)
        torch_out = as_float32_tensor(x) / scales
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)

    return torch_out, scales


# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.Tensor) \
                    -> tuple[torch.Tensor, torch.Tensor]:
    """Reference dynamic per-tensor FP8 quantization.

    A single scale is derived from the absolute maximum of the whole tensor;
    ``x`` is multiplied by the reciprocal of that scale, clamped to the FP8
    range and cast to ``FP8_DTYPE``.

    Args:
        x: Tensor to quantize.

    Returns:
        A ``(quantized, scale)`` tuple: the quantized tensor in
        ``FP8_DTYPE`` and the float32 scale viewed as a 1-element tensor.
    """
    fp8_traits = torch.finfo(FP8_DTYPE)
    # On ROCm the range is narrowed to +/-ROCM_FP8_MAX instead of the limits
    # reported by torch.finfo.
    on_rocm = current_platform.is_rocm()
    fp8_traits_max = ROCM_FP8_MAX if on_rocm else fp8_traits.max
    fp8_traits_min = -ROCM_FP8_MAX if on_rocm else fp8_traits.min
    fp8_max = as_float32_tensor(fp8_traits_max)
    one = as_float32_tensor(1.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    x_max = as_float32_tensor(x.abs().max())
    ref_scale = x_max / fp8_max
    ref_iscale = one / ref_scale
    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return ref_out, ref_scale.view((1, ))
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149


def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
                             As: torch.Tensor, Bs: torch.Tensor, block_size,
                             output_dtype):
    """Block-wise scaled matmul ``A @ B.T`` implemented with plain torch ops.

    Both operands are converted to float32, so the function does not care
    which quantized dtype they were stored in. Every ``(block_n, block_k)``
    tile of the product is weighted by the matching activation and weight
    scales before being accumulated.

    Args:
        A: Left operand of shape ``(..., K)``.
        B: Right operand of shape ``(N, K)``; must be 2-D and contiguous.
        As: Scales for ``A`` with shape ``(..., ceil(K / block_k))``; leading
            dimensions must match those of ``A``.
        Bs: Scales for ``B`` with shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element sequence ``(block_n, block_k)``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.
    """

    def _ceil_div(numer, denom):
        # Integer ceiling division used for all tile counts.
        return (numer + denom - 1) // denom

    lhs = A.to(torch.float32)
    rhs = B.to(torch.float32)
    assert lhs.shape[-1] == rhs.shape[-1]
    assert rhs.ndim == 2 and rhs.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    inner_dim = lhs.shape[-1]
    assert _ceil_div(inner_dim, block_k) == As.shape[-1]
    assert lhs.shape[:-1] == As.shape[:-1]

    # Collapse all leading dimensions of A (and its scales) into one row axis.
    num_rows = lhs.numel() // inner_dim
    N, K = rhs.shape
    result_shape = lhs.shape[:-1] + (N, )
    lhs = lhs.reshape(num_rows, inner_dim)
    lhs_scales = As.reshape(num_rows, As.shape[-1])

    n_tiles = _ceil_div(N, block_n)
    k_tiles = _ceil_div(K, block_k)
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    acc = torch.zeros((num_rows, N), dtype=torch.float32, device=lhs.device)

    for j in range(n_tiles):
        n_lo = j * block_n
        n_hi = min(n_lo + block_n, N)
        # View into the accumulator; in-place adds below write through to it.
        acc_tile = acc[:, n_lo:n_hi]
        for i in range(k_tiles):
            k_lo = i * block_k
            k_hi = min(k_lo + block_k, K)
            tile_scale = lhs_scales[:, i:i + 1] * Bs[j][i]
            tile_prod = torch.matmul(lhs[:, k_lo:k_hi],
                                     rhs[n_lo:n_hi, k_lo:k_hi].t())
            acc_tile += tile_prod * tile_scale

    return acc.reshape(result_shape).to(output_dtype)