"git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "c312f1756760a7cf66f24833dc2bf27be2e40433"
fake_quant.py 6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FakeQuant Feature support for nvidia-dlframework-inspect"""

from typing import Optional

import torch

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
from nvdlfw_inspect.utils import append_parent_docstring


import transformer_engine_torch as tex
from transformer_engine.debug.features.api import TEConfigAPIMapper
from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.fp8 import _default_sf_compute


def fake_quantize(tensor: torch.Tensor, fp8_format: str, out: Optional[torch.Tensor] = None):
    """Quantize ``tensor`` to an FP8 format and immediately dequantize it back.

    The returned values are what the tensor would look like after FP8 quantization,
    but they are kept in a high-precision dtype so a high-precision GEMM can consume them.

    Args:
        tensor: CUDA tensor of dtype float32, float16 or bfloat16.
        fp8_format: one of ``"FP8E4M3"``, ``"FP8E5M2"`` (per-tensor scale computed from
            the current amax of ``tensor`` with ``Float8Quantizer``) or ``"MXFP8E4M3"``,
            ``"MXFP8E5M2"`` (quantized with ``MXFP8Quantizer``).
        out: optional preallocated tensor. When given, the dequantized result is written
            into it with ``copy_`` and ``None`` is returned.

    Returns:
        The dequantized tensor, or ``None`` when ``out`` is provided.

    Raises:
        AssertionError: if the dtype is unsupported, the tensor is not on a CUDA device,
            or ``fp8_format`` is not one of the four supported names.
    """

    assert tensor.dtype in (
        torch.float,
        torch.float16,
        torch.bfloat16,
    ), "[NVTORCH INSPECT ERROR] Unsupported tensor type."
    assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor."
    assert fp8_format in {
        "FP8E4M3",
        "FP8E5M2",
        "MXFP8E4M3",
        "MXFP8E5M2",
    }, (
        "[NVTORCH INSPECT ERROR] Only 4 FP8 types: FP8E4M3, FP8E5M2, MXFP8E4M3, MXFP8E5M2 are"
        " supported in TE."
    )

    # Both the per-tensor and the MX variants share the same FP8 element dtype; the
    # quantizers expect a tex.DType, not the config string.
    if fp8_format.endswith("E4M3"):
        fp8_dtype = tex.DType.kFloat8E4M3
    else:
        fp8_dtype = tex.DType.kFloat8E5M2

    if fp8_format in ("FP8E4M3", "FP8E5M2"):
        if fp8_format == "FP8E4M3":
            fp8_max = Format.E4M3.value.max_fwd
        else:
            fp8_max = Format.E5M2.value.max_fwd
        # Scale is derived from this tensor's own amax (current scaling), not from history.
        amax = tensor.abs().max().float()
        one = torch.ones(1, device=tensor.device)
        scale = _default_sf_compute(amax, one, fp8_max, 0)
        quantizer = Float8Quantizer(scale, amax, fp8_dtype)
    else:
        # Previously the raw string (e.g. "MXFP8E4M3") was passed as fp8_dtype.
        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype)

    dequantized = quantizer(tensor).dequantize()
    if out is not None:
        out.copy_(dequantized)
        return None
    return dequantized


@Registry.register_feature(namespace="transformer_engine")
@append_parent_docstring(parent=TEConfigAPIMapper)
class FakeQuant(TEConfigAPIMapper):
    """

    Disables FP8 GEMM. Fake quantizes chosen tensors to FP8 and runs high-precision GEMM.
    For FP8E4M3/FP8E5M2 a per-tensor scaling factor computed from the current tensor is
    used (not delayed scaling); for MXFP8E4M3/MXFP8E5M2 the MXFP8 quantizer is used.

    .. figure:: ./img/fake_quant.svg
        :align: center

        Fig 1: Comparison of FP8 FPROP GEMM with the same GEMM in BF16 with fake quantization of activation tensor. Green tensors have the same values, but different dtypes.



    Parameters
    ----------

    gemms/gemms_struct: List[str]
        list of gemms to fake quantize

            - fprop
            - dgrad
            - wgrad
    tensors/tensors_struct: List[str]
        list of tensors to fake quantize

            - activation
            - gradient
            - weight
            - output
            - wgrad
            - dgrad

    quant_format: str
        specifies the FP8 format to use:

            - FP8E5M2
            - FP8E4M3
            - MXFP8E5M2
            - MXFP8E4M3

    Example
    -------
    .. code-block:: yaml

        example_fake_quant_fp8:
            enabled: True
            layers:
                layer_types: [transformer_layer.layernorm_mlp.fc1]
            transformer_engine:
                FakeQuant:
                    enabled: True
                    quant_format: FP8E5M2
                    gemms_struct:
                    - gemm: fprop
                      tensors: [activation, weight]
                    - gemm: dgrad
                      tensors: [gradient]
    """

    def _supported_formats(self):
        """Return the ``quant_format`` names accepted by this feature."""
        return ["FP8E4M3", "FP8E5M2", "MXFP8E4M3", "MXFP8E5M2"]

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call selecting between high-precision and FP8 GEMM execution.

        Always returns False: GEMMs run in high precision on fake-quantized inputs.
        """
        return False

    @api_method
    def modify_tensor_enabled(
        self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call deciding whether ``modify_tensor()`` should run; always True."""
        return True

    @api_method
    def modify_tensor(
        self,
        config,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        tensor: torch.Tensor,
        iteration: int,
        default_quantizer: Quantizer,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ):  # pylint: disable=unused-argument
        """API call that fake-quantizes ``tensor`` according to ``config``.

        Args:
            config: feature config; only the keys "gemm", "tensor" and "quant_format"
                are accepted, and "quant_format" is required.
            out: optional preallocated tensor. When given, the result is written into it
                and ``None`` is returned (its own dtype governs, so ``dtype`` is unused).
            dtype: optional dtype to cast the returned tensor to when ``out`` is None.

        Returns:
            The fake-quantized tensor, or ``None`` when ``out`` is provided.

        Raises:
            ValueError: on an unexpected config key, a missing ``quant_format``, or an
                unsupported ``quant_format``.
        """

        for key in config.keys():
            if key not in ["gemm", "tensor", "quant_format"]:
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

        if "quant_format" not in config:
            raise ValueError(
                f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor:"
                f" quant_format missing for Tensor: {tensor_name} in the config yaml for"
                " FakeQuant feature which is a required field"
            )
        if config["quant_format"] not in self._supported_formats():
            raise ValueError(
                f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor:"
                f" quant_format: {config['quant_format']} for Tensor: {tensor_name} in the config"
                " yaml for FakeQuant feature is not supported"
            )
        debug_api.log_message(
            f"Feature={self.__class__.__name__}, API=process_tensor: {gemm}, {tensor_name}",
            layer_name,
            extra_cachable_args=(gemm, tensor_name),
        )

        quant_format = config["quant_format"]
        if out is not None:
            # fake_quantize copies into ``out`` in place and returns None; previously
            # ``None.to(dtype)`` was attempted here and raised AttributeError.
            fake_quantize(tensor, quant_format, out=out)
            return None
        q_tensor = fake_quantize(tensor, quant_format)
        if dtype is not None:
            q_tensor = q_tensor.to(dtype)
        return q_tensor