quant_utils.py 3.44 KB
Newer Older
1
from typing import Optional, Tuple, Union
2
3
4

import torch

5
from vllm.platforms import current_platform
6
7
8
9

# PyTorch's default maximum (240.0) causes accuracy issues on dynamic
# quantization models, so 224.0 is used as the FP8 maximum on ROCm instead.
ROCM_FP8_MAX = 224.0

# FP8 dtype used by the reference quantizers: the "fnuz" e4m3 variant is
# selected when running on ROCm, the plain "fn" e4m3 variant otherwise.
if current_platform.is_rocm():
    FP8_DTYPE = torch.float8_e4m3fnuz
else:
    FP8_DTYPE = torch.float8_e4m3fn
12

13
14
15
16
17

def as_float32_tensor(x: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return ``x`` as a float32 tensor on the ``'cuda'`` device.

    Args:
        x: A Python float or a tensor to convert.

    Returns:
        The result of ``torch.as_tensor`` called with
        ``dtype=torch.float32`` and ``device='cuda'``.
    """
    return torch.as_tensor(x, dtype=torch.float32, device='cuda')

def ref_dynamic_per_token_quant(x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                scale_ub: Optional[torch.Tensor] = None) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation of dynamic per-token quantization.

    A scale is derived from the maximum absolute value along the last
    dimension of ``x`` and ``x`` is then quantized to ``quant_dtype`` with
    that scale.

    Args:
        x: Input tensor. The ``[:, None]`` indexing of the scales assumes a
            2-D input -- TODO confirm against callers.
        quant_dtype: Target dtype; must be ``torch.int8`` or ``FP8_DTYPE``.
        scale_ub: Optional upper bound applied to the per-token absolute
            maximum before the scales are computed. Only permitted when
            ``quant_dtype`` is ``FP8_DTYPE``.

    Returns:
        A ``(quantized, scales)`` tuple. ``quantized`` has dtype
        ``quant_dtype``; ``scales`` is float32 with a trailing singleton
        dimension. For FP8 the returned scales include the lower clamp of
        ``1 / (qtype_max * 512)``.

    Raises:
        AssertionError: If ``quant_dtype`` is unsupported, or ``scale_ub``
            is given for a non-FP8 ``quant_dtype``.
    """
    assert quant_dtype in [torch.int8, FP8_DTYPE]
    if scale_ub is not None:
        assert quant_dtype == FP8_DTYPE

    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
            else torch.finfo(quant_dtype)

    # ROCM_FP8_MAX is an FP8-specific limit. It must not replace the int8
    # range: 224 lies outside what int8 can represent, so using it for int8
    # would produce wrong scales and an out-of-range clamp before the cast.
    use_rocm_fp8_range = (quant_dtype == FP8_DTYPE
                          and current_platform.is_rocm())
    qtype_traits_max = ROCM_FP8_MAX if use_rocm_fp8_range \
                                        else qtype_traits.max
    qtype_traits_min = -ROCM_FP8_MAX if use_rocm_fp8_range \
                                        else qtype_traits.min
    qtype_max = as_float32_tensor(qtype_traits_max)
    s_1 = as_float32_tensor(1.0)
    s_512 = as_float32_tensor(512.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    # Compute scales
    x_token_max, _ = x.abs().max(dim=-1)
    x_token_max = as_float32_tensor(x_token_max)
    if scale_ub is not None:
        x_token_max = x_token_max.clamp(max=scale_ub)
    scales = (x_token_max / qtype_max)[:, None]

    # Quant
    if quant_dtype == torch.int8:
        # int8 multiplies by the reciprocal scale and rounds before clamping.
        iscales = as_float32_tensor(s_1 / scales)
        torch_out = as_float32_tensor(x) * iscales
        torch_out = torch_out.round()
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)
    else:
        assert quant_dtype == FP8_DTYPE
        # fp8 bounds the scale from below, then divides by it directly.
        min_scaling_factor = s_1 / (qtype_max * s_512)
        scales = scales.clamp(min=min_scaling_factor)
        torch_out = as_float32_tensor(x) / scales
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)

    return torch_out, scales


# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.Tensor) \
                    -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation of dynamic per-tensor FP8 quantization.

    A single scale is derived from the maximum absolute value over all of
    ``x``; ``x`` is multiplied by the reciprocal of that scale, clamped to
    the FP8 range and cast to ``FP8_DTYPE``.

    Args:
        x: Input tensor to quantize.

    Returns:
        A ``(quantized, scale)`` tuple. ``quantized`` has dtype
        ``FP8_DTYPE``; ``scale`` is the float32 scale viewed with shape
        ``(1,)``.
    """
    fp8_traits = torch.finfo(FP8_DTYPE)
    # On ROCm the range is narrowed to +/-ROCM_FP8_MAX instead of the
    # limits reported by torch.finfo.
    fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
                                    else fp8_traits.max
    fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
                                    else fp8_traits.min
    fp8_max = as_float32_tensor(fp8_traits_max)
    one = as_float32_tensor(1.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    x_max = as_float32_tensor(x.abs().max())
    ref_scale = x_max / fp8_max
    ref_iscale = one / ref_scale
    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return ref_out, ref_scale.view((1, ))