# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
DEPRECATED in favor of `transformer_engine.pytorch.quantization.py`.

This module is a thin backwards-compatibility shim: importing it emits a
DeprecationWarning and re-exports the public names that used to live here.
"""

# pylint: disable=wrong-import-position,unused-import

import warnings

# stacklevel=2 points the warning at the caller's import site, not this module.
warnings.warn(
    "Using deprecated internal API from Transformer Engine. "
    "transformer_engine.pytorch.fp8 will be removed in a "
    "future release.",
    DeprecationWarning,
    stacklevel=2,
)


# There are some users indirectly importing these classes
# from fp8.py. This ensures backwards compatibility.
# https://github.com/Lightning-AI/lightning-thunder/pull/2635.
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
    NVFP4BlockScaling,
    CustomRecipe,
)

# Importing each function instead of 'import *' allows us to specify '__all__'
# in quantize.py and also makes any newer additions to quantize.py invisible
# via fp8.py so that we don't reinforce importing internal TE functions.
from .quantization import (
    check_fp8_support,
    check_mxfp8_support,
    check_nvfp4_support,
    check_fp8_block_scaling_support,
    check_recipe_support,
    get_default_fp8_recipe,
    get_fp8_torch_dtype,
    get_fp8_te_dtype,
    get_fp4_te_dtype,
    get_fp8_max,
    FP8GlobalStateManager,
    fp8_model_init,
    fp8_autocast,
    _update_amax_history,
    _default_get_amax_and_update_history,
    _default_sf_compute,
    _compute_amax_and_update_history,
    _compute_scaling_factor,
    _amax_and_scale_update,
    split_and_copy,
    RecipeState,
    DelayedScalingRecipeState,
    Float8CurrentScalingRecipeState,
    MXFP8BlockScalingRecipeState,
    Float8BlockScalingRecipeState,
    NVFP4BlockScalingRecipeState,
    CustomRecipeState,
    int8_simulation_fp8,
    int8_simulation_fp8_tensorwise,
    blockwise_fp8_block_len,
)