# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import subprocess
from pathlib import Path

import pytest
import torch

import transformer_engine.pytorch as te
from torch.utils.cpp_extension import IS_HIP_EXTENSION

"""
    Distributed numerics tests

    These tests test the numerical corectness of the TransformerEngine layers.
    Tests are parametrized by the layer and fp8 precision.
    One test consists of running multiple configurations from file run_numerics.py
    Such design is due to the fact the initialization of one test is long
    - 2 processes need to start and load torch and TE. Multiple configurations
    are run in one test - this reduces the initialization overhead.

"""


# Skip the entire module when fewer than 2 GPUs are visible.
# NOTE: pytest.skip called at module level (outside a test) requires
# allow_module_level=True; without it pytest raises a usage error instead
# of skipping.
if torch.cuda.device_count() < 2:
    pytest.skip("Distributed training needs at least 2 GPUs.", allow_module_level=True)

# Probe, once at import time, which quantization recipes the local hardware
# supports; each probe also yields a human-readable reason string that the
# tests below reuse as their skip message.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
(
    fp8_block_scaling_available,
    reason_for_no_fp8_block_scaling,
) = te.is_fp8_block_scaling_available(return_reason=True)


# Directory containing this file; run_numerics.py is expected alongside it.
TEST_ROOT = Path(__file__).parent.resolve()
# Use at most four ranks, bounded by the GPUs actually present.
NUM_PROCS: int = min(torch.cuda.device_count(), 4)
# Base torchrun invocation that every test command extends.
LAUNCH_CMD = ["torchrun", "--nproc_per_node=" + str(NUM_PROCS)]
def _run_test(quantization):
    """Launch run_numerics.py under torchrun and assert it exits cleanly.

    Args:
        quantization: Name of the quantization recipe forwarded to the child
            script via ``--quantization`` (e.g. "fp8", "mxfp8"), or None to
            run without a quantization flag (higher precision).
    """
    test_path = TEST_ROOT / "run_numerics.py"
    test_cmd = LAUNCH_CMD + [str(test_path)]

    if quantization is not None:
        test_cmd += ["--quantization", quantization]

    # check=False so a non-zero exit surfaces as a pytest assertion failure
    # (with context) rather than a raw CalledProcessError.
    result = subprocess.run(test_cmd, env=os.environ, check=False)
    assert result.returncode == 0, (
        f"Command {test_cmd} failed with exit code {result.returncode}"
    )

# Truth-table list for boolean test parametrization; not referenced in this
# file — NOTE(review): possibly kept for symmetry with sibling test files,
# confirm before removing.
all_boolean = [True, False]

@pytest.mark.parametrize(
    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_distributed(quantization):
    """Run the distributed numerics suite for one quantization recipe."""
    # Map each recipe to its availability probe and skip reason; "fp8_cs"
    # shares the plain fp8 probe.
    support = {
        "fp8": (fp8_available, reason_for_no_fp8),
        "fp8_cs": (fp8_available, reason_for_no_fp8),
        "mxfp8": (mxfp8_available, reason_for_no_mxfp8),
        "fp8_block_scaling": (fp8_block_scaling_available, reason_for_no_fp8_block_scaling),
        "nvfp4": (nvfp4_available, reason_for_no_nvfp4),
    }
    if quantization is not None:
        supported, reason = support[quantization]
        if not supported:
            pytest.skip(reason)
    _run_test(quantization)

@pytest.mark.parametrize(
    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_int8_distributed(quantization):
    """Run the distributed numerics suite with int8-simulated fp8 on ROCm.

    On HIP builds, fp8_block_scaling is exercised with NVTE_INT8_SIM_FP8=1 so
    the run uses int8-simulated fp8 kernels; te.fp8 is reloaded so it re-reads
    the environment variable. The prior environment state is restored (even on
    failure) so later tests are unaffected.
    """
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_cs" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    use_int8_sim = IS_HIP_EXTENSION and quantization == "fp8_block_scaling"
    if use_int8_sim:
        import importlib

        # None (as opposed to the old "None" string sentinel, which could
        # collide with a legal value) unambiguously means "was unset".
        ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8")
        os.environ["NVTE_INT8_SIM_FP8"] = "1"
        # te.fp8 reads the environment at import time, so force a re-read.
        importlib.reload(te.fp8)

    try:
        _run_test(quantization)
    finally:
        # Restore the environment even if the test run fails, so subsequent
        # tests in this process do not inherit the int8-sim setting.
        if use_int8_sim:
            if ori_int8_sim_fp8 is None:
                # Variable was unset before the test; remove it again.
                del os.environ["NVTE_INT8_SIM_FP8"]
            else:
                # Put back the exact value that was set before the test.
                os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
            importlib.reload(te.fp8)