test_cpu_offloading.py 6.06 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
import os
from contextlib import nullcontext
7
8
9
10
import pytest
import torch

import transformer_engine.pytorch as te
11
from transformer_engine.common import recipe
12
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
13
14
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from utils import ModelConfig, get_available_attention_backends
15

16
# Check if FP8 is supported
17
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
18
19
20
21
22
23
24
25
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

fp8_recipes = [
    None,  # non-fp8
    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet
    recipe.Float8CurrentScaling(),
    recipe.DelayedScaling(),
]
26

27
28
29
30
31
32
33
model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
34
35
36
37

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
#
# Each entry is a zero-argument factory so every test case builds fresh layers.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
}


def _get_input():
    """Return an uninitialized bfloat16 tensor of shape (128, SIZE, SIZE) on the GPU.

    Values are irrelevant here — the tests only measure memory, not numerics.
    """
    batch = torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16)
    return batch.cuda()


def _get_fp8_weight_cache_size(models, fp8_recipe):
    """
    Calculate the total FP8 weight cache size (in MB) for a list of models.
    """
    if fp8_recipe is None:
        return 0
67

68
69
70
71
72
    params_bytes = 0
    for model in models:
        for name, param in model.named_parameters():
            if "weight" in name:
                params_bytes += param.numel()
73

74
75
76
77
78
    # One byte for columnwise and one byte for rowwise,
    # hence multiply by 2 and convert to MB
    # there is 1 byte of scale per 32 elements in mxFP8
    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
    return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
79

80

81
82
def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
    """
    Run a forward pass through ``models`` sequentially and return the GPU
    memory allocated (in MB) at the point between forward and backward.

    When ``cpu_offload`` is True, activations of all layers except the last
    are offloaded to the CPU (``num_layers=len(models) - 1``); when
    ``fp8_recipe`` is not None, each layer runs under ``te.fp8_autocast``.
    """
    tensor = _get_input()
    if cpu_offload:
        # Offload all but the last layer; the last layer's activations are
        # needed immediately by backward, so offloading them would not help.
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(models) - 1,
            model_layers=len(models),
            offload_activations=True,
            offload_weights=False,
        )
    else:
        # No-op context and identity sync when offloading is disabled, so the
        # forward loop below is identical in both configurations.
        offload_context = nullcontext()
        sync_function = lambda x: x

    for model in models:
        with te.fp8_autocast(
            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
        ), offload_context:
            tensor = model(tensor)
        # NOTE(review): sync_function is the synchronization hook returned by
        # get_cpu_offload_context — presumably it flushes the pending
        # activation offloads for the layer; confirm against TE docs.
        tensor = sync_function(tensor)

    # Snapshot currently-allocated GPU memory: this is the quantity held
    # between forward and backward that the test compares.
    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
    torch.cuda.synchronize()

    tensor.sum().backward()

    return max_mem_used
108

109

110
111
112
113
114
115
116
117
118
119
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model_key", model_types.keys())
def test_cpu_offload(fp8_recipe, model_key) -> None:
    """
    We run three configurations:
    (1) No offloading: All activations remain on the GPU between forward and backward passes.
    (2) No offloading (one layer): Only the first layer's activations remain on the GPU between
        forward and backward passes.
    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
        between forward and backward passes, while all other layers are offloaded to the CPU.

    We expect the memory consumption of configurations (2) and (3) to be similar, with
    the difference being the size of the FP8 cache that is not offloaded to the CPU.
    We also expect this memory consumption to be smaller than in scenario (1).
    """
    import gc

    # Skip unsupported configurations *before* building any models, so a
    # skipped case does not needlessly allocate NUM_LAYERS GPU modules.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)

    if model_key in ["multihead_attention", "transformer_layer"]:
        # Offloading attention activations requires the fused attention
        # backend (flash attention is disabled for this module — see the
        # module-level assert).
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
        os.environ["NVTE_FLASH_ATTN"] = "0"
        # Force TE to re-evaluate its attention-backend selection.
        _attention_backends["backend_selection_requires_update"] = True

    # Start from a clean slate so garbage from earlier tests does not skew
    # the memory measurements below.
    gc.collect()

    models_list = [model_types[model_key]() for _ in range(NUM_LAYERS)]

    without_offloading = _measure_memory_between_forward_and_backward(
        models_list, fp8_recipe, False
    )
    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
        models_list[:1], fp8_recipe, False
    )
    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)

    # Offloading must strictly reduce memory held between forward and backward.
    assert with_offloading < without_offloading

    # The only difference between the memory consumption of with_offloading
    # and without_offloading_one_layer should be the size of the FP8 weights cache,
    # which is not offloaded to the CPU.
    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
    assert (
        memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
    )