# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
import os
from contextlib import nullcontext
7
8
9
10
import pytest
import torch

import transformer_engine.pytorch as te
11
from transformer_engine.common import recipe
12
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
13

14
# Check if FP8 / MXFP8 are supported; the accompanying reason strings are used
# as skip messages by the test below.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Quantization recipes the offloading test is parametrized over.
fp8_recipes = [
    None,  # non-fp8
    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading does not work yet
    recipe.Float8CurrentScaling(),
    recipe.DelayedScaling(),
]
24
25

# Hidden size of every model under test; also the two trailing input dimensions.
SIZE = 512
# Number of attention heads for the attention-based models.
NUM_HEADS = 8
# Number of model instances chained together in one measurement.
NUM_LAYERS = 5
# Slack (in MB) allowed when comparing memory measurements.
EPSILON = 0.1

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert (
    os.getenv("NVTE_FLASH_ATTN") == "0"
), "NVTE_FLASH_ATTN must be set to 0: flash attention keeps tensors that cannot be offloaded."
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
#
# Each value is a zero-argument factory so that every test builds fresh modules.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
}


def _get_input():
    """
    Create the input tensor fed to the first model of a measurement run.

    Returns an uninitialized bfloat16 tensor of shape (128, SIZE, SIZE) that has
    been moved to the current CUDA device. The values are irrelevant because the
    test only inspects memory usage.
    """
    return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()


def _get_fp8_weight_cache_size(models, fp8_recipe):
    """
    Calculate the total FP8 weight cache size (in MB) for a list of models.

    Every parameter whose name contains "weight" is counted. Without a recipe
    (``fp8_recipe is None``) there is no FP8 cache and the result is 0.

    Args:
        models: iterable of modules exposing ``named_parameters()``.
        fp8_recipe: quantization recipe exposing ``mxfp8()``, or ``None``.

    Returns:
        The estimated cache size in MB (0 when no recipe is given).
    """
    if fp8_recipe is None:
        return 0

    # Number of weight elements; each one takes a single byte per FP8 copy.
    weight_elements = sum(
        param.numel()
        for model in models
        for name, param in model.named_parameters()
        if "weight" in name
    )

    # One byte for columnwise and one byte for rowwise,
    # hence multiply by 2 and convert to MB
    # there is 1 byte of scale per 32 elements in mxFP8
    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
    return (2 * weight_elements * factor_for_scale_inv_tensor) / (1024**2)
74

75

76
77
def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
    """
    Run ``models`` as a chain and report GPU memory held after the forward pass.

    Args:
        models: list of modules applied one after another to the same tensor.
        fp8_recipe: recipe passed to ``te.fp8_autocast``; ``None`` disables FP8.
        cpu_offload: when True, the forward pass runs inside the context returned
            by ``te.get_cpu_offload_context`` with activation offloading enabled
            for all layers but the last one.

    Returns:
        ``torch.cuda.memory_allocated()`` in MB, sampled after the forward pass
        and before the backward pass.
    """
    tensor = _get_input()
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(models) - 1,
            model_layers=len(models),
            offload_activations=True,
            offload_weights=False,
        )
    else:
        # No offloading: neutral context manager and identity sync function.
        offload_context = nullcontext()
        sync_function = lambda x: x

    for model in models:
        with te.fp8_autocast(
            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
        ), offload_context:
            tensor = model(tensor)
        # The sync function is applied to each layer's output, outside the context.
        tensor = sync_function(tensor)

    # Sample the allocation while the tensors saved for backward are still held.
    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
    torch.cuda.synchronize()

    # Run the backward pass as well.
    # NOTE(review): presumably so saved/offloaded tensors are released before the
    # next measurement -- confirm.
    tensor.sum().backward()

    return max_mem_used
103

104

105
106
107
108
109
110
111
112
113
114
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model_key", model_types.keys())
def test_cpu_offload(fp8_recipe, model_key) -> None:
    """
    We run three configurations:
    (1) No offloading: All activations remain on the GPU between forward and backward passes.
    (2) No offloading (one layer): Only the first layer's activations remain on the GPU between
        forward and backward passes.
    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
        between forward and backward passes, while all other layers are offloaded to the CPU.

    We expect the memory consumption of configurations (2) and (3) to be similar, with
    the difference being the size of the FP8 cache that is not offloaded to the CPU.
    We also expect this memory consumption to be smaller than in scenario (1).
    """
    # Decide whether to skip before building any model, so that an unsupported
    # configuration does not construct NUM_LAYERS modules for nothing.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)

    import gc

    # Collect garbage left over from earlier tests before measuring memory.
    gc.collect()

    model_cls = model_types[model_key]
    models_list = [model_cls() for _ in range(NUM_LAYERS)]

    without_offloading = _measure_memory_between_forward_and_backward(
        models_list, fp8_recipe, False
    )
    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
        models_list[:1], fp8_recipe, False
    )
    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)

    assert with_offloading < without_offloading

    # The only difference between the memory consumption of with_offloading
    # and without_offloading_one_layer should be the size of the FP8 weights cache,
    # which is not offloaded to the CPU.
    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
    assert (
        memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
    )