# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch
from contextlib import nullcontext

import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Check if FP8 supported on this device (skip fp8 parametrizations otherwise).
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

# Feature dimension shared by every layer and by the test input tensor.
SIZE = 512

# TE module classes under test, keyed by the pytest parameter id.
models = {
    "linear": te.Linear,
    "layernorm_mlp": te.LayerNormMLP,
    "layernorm_linear": te.LayerNormLinear,
}


def _get_input():
    """Return an uninitialized (128, SIZE, SIZE) float tensor on the CUDA device."""
    return torch.empty((128, SIZE, SIZE)).cuda()


def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
    """Run a 3-layer stack forward and report GPU memory held before backward.

    Args:
        model_cls: TE layer class constructed as ``model_cls(SIZE, SIZE)``
            (e.g. ``te.Linear``).
        fp8: whether to run each layer under ``te.fp8_autocast``.
        cpu_offload: whether to enable TE CPU activation offloading for the
            first two of the three layers (``num_layers=2, model_layers=3``);
            weights are not offloaded.

    Returns:
        Memory allocated on the current CUDA device in MiB, sampled after the
        forward pass and before ``backward()`` — the point where activation
        offloading should show its savings.
    """
    input_layer = model_cls(SIZE, SIZE)
    hidden_layer = model_cls(SIZE, SIZE)
    output_layer = model_cls(SIZE, SIZE)

    # ``inp`` rather than ``input`` to avoid shadowing the builtin.
    inp = _get_input()
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=2,
            model_layers=3,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        # No-op stand-ins so the layer loop below is shape-identical
        # in both modes.
        offload_context = nullcontext()
        sync_function = lambda x: x

    # Each layer runs inside its own autocast/offload context;
    # ``sync_function`` is applied to the output before it feeds the next
    # layer, as required by TE's offloading API.
    with te.fp8_autocast(enabled=fp8), offload_context:
        out = input_layer(inp)
    out = sync_function(out)
    with te.fp8_autocast(enabled=fp8), offload_context:
        out = hidden_layer(out)
    out = sync_function(out)
    with te.fp8_autocast(enabled=fp8), offload_context:
        out = output_layer(out)
    out = sync_function(out)

    # NOTE(review): this samples the *current* allocation at the fwd/bwd
    # boundary, not the peak (torch.cuda.max_memory_allocated). That is
    # sufficient for comparing offloading on vs. off.
    mem_used_mib = torch.cuda.memory_allocated() / 1024**2

    out.sum().backward()

    # Drop all references so a subsequent measurement in the same process
    # is not polluted by this run's allocations.
    del input_layer
    del hidden_layer
    del output_layer
    del inp
    del out

    torch.cuda.synchronize()

    return mem_used_mib


@pytest.mark.parametrize("fp8", [True, False])
@pytest.mark.parametrize("model_key", models.keys())
def test_cpu_offload(fp8, model_key) -> None:
    """CPU activation offloading must reduce GPU memory held between fwd and bwd."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    model_cls = models[model_key]

    without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False)
    with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True)

    # Offloading the first two layers' activations to host memory should
    # strictly lower GPU memory at the fwd/bwd boundary.
    assert with_offloading < without_offloading