# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Recompute"""

from pathlib import Path
import re
import subprocess
import sys

import numpy as np
import pytest

from transformer_engine.paddle.fp8 import is_fp8_available

# Directory containing this test module; the helper script launched by the
# test is located relative to it.
test_root = Path(__file__).resolve().parent
# Result of the FP8 availability query, unpacked into a flag and a reason
# string; both are passed to pytest's skipif marker on the test.
is_fp8_supported, reason = is_fp8_available()


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_reentrant', [False, True])
def test_transformer_encoder_recompute(use_reentrant):
    """
    Test TransformerLayer encoder recompute

    Runs ``recompute_tests/recompute_transformer_encoder.py`` twice in a
    subprocess, once with recompute enabled and once with it disabled, and
    checks that the run with recompute reports a lower peak memory while the
    reported loss matches the reference run within ``rtol``/``atol``.

    Args:
        use_reentrant: Forwarded to the script as its second command-line
            argument, encoded as ``0`` or ``1``.
    """
    rtol = 1e-5
    atol = 1e-5

    # Accept both fixed-point and scientific notation: a pattern that only
    # matches ``\d+\.\d+`` would read "1.5e-05" as 1.5 and silently compare
    # the wrong values.
    loss_pattern = re.compile(r'Loss:\s+(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
    memory_pattern = re.compile(r'Peak memory:\s+(\d+)')

    def launch_subprocess_and_check_output(enable_recompute):
        """Launch training in subprocess and check output

        Returns:
            A ``(loss, peak_memory)`` tuple parsed from the script's output.

        Raises:
            ValueError: If the subprocess exits with a non-zero status, or if
                its output lacks the ``Loss:`` or ``Peak memory:`` line.
        """
        cmd = [
            # Use the interpreter running this test so the child process sees
            # the same environment and installed packages.
            sys.executable,
            str(test_root / 'recompute_tests' / 'recompute_transformer_encoder.py'),
            str(int(enable_recompute)),
            str(int(use_reentrant))
        ]
        # Keep the try body limited to the call that can raise
        # CalledProcessError.
        try:
            result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
        except subprocess.CalledProcessError as e:
            # Include the captured output; the exception's own message only
            # carries the command and the exit status.
            raise ValueError(f"Subprocess failed with error: {e}\nOutput:\n{e.output}") from e

        print(result)

        loss_match = loss_pattern.search(result)
        memory_match = memory_pattern.search(result)

        # Fail with an explicit message instead of an AttributeError on None.
        if loss_match is None:
            raise ValueError("Could not find 'Loss:' value in subprocess output")
        if memory_match is None:
            raise ValueError("Could not find 'Peak memory:' value in subprocess output")

        loss_value = float(loss_match.group(1))
        memory_value = int(memory_match.group(1))

        return loss_value, memory_value

    loss_recompute, peak_memory_recompute = launch_subprocess_and_check_output(True)
    loss_ref, peak_memory_ref = launch_subprocess_and_check_output(False)

    assert peak_memory_recompute < peak_memory_ref
    np.testing.assert_allclose(loss_recompute, loss_ref, rtol=rtol, atol=atol)