# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import contextlib
import gc
import os
from typing import Iterable, Optional

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends

# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()

quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))

if mxfp8_available:
    quantization_recipes.append(recipe.MXFP8BlockScaling())

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
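
# For scale: with SIZE = 512, a single 512x512 weight is 0.5 MiB in bf16
# and 0.25 MiB as FP8 data.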

# Flash attention saves some internal tensors for the backward pass
# that cannot be offloaded to CPU, so disable it for these tests.
os.environ["NVTE_FLASH_ATTN"] = "0"

# Offloading of attention activations is supported only with the fused and flash
# attention backends, so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
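    # Operation-based (te.ops) counterparts of the modules above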
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}


def _make_input() -> torch.Tensor:
    """Generate random input tensor."""
    return torch.randn(
        (128, SIZE, SIZE),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )


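# A warmup pass is run before measuring memory so that one-time allocations
# (e.g. quantizer state and workspaces) do not skew the measurement.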
def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Perform forward and backward pass"""
    tensor = _make_input()
    for module in modules:
        with te.fp8_autocast(
            enabled=quantization_recipe is not None,
            fp8_recipe=quantization_recipe,
        ):
            tensor = module(tensor)
    tensor.sum().backward()


def _estimate_cached_weight_size(
    model_name: str,
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> float:
    """Calculate the memory (in MiB) needed for weight caching."""

    # The weight params are cached directly for unquantized compute
    if quantization_recipe is None:
        return 0

    # Count number of weight param elements
    param_elements = 0
    for module in modules:
        for param in module.parameters():
            if param.dim() == 2:
                param_elements += param.numel()

    # FP8 tensor-scaling caches one byte per element
    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
        if not is_non_tn_fp8_gemm_supported() and model_name not in (
            "linear_op",
            "layernorm_mlp_ops",
        ):
            # Modules do not deallocate FP8 transpose for weights
            return 2 * param_elements / 1024**2
        return param_elements / 1024**2

    # MXFP8 caches one data byte per element and one scale byte per 32
    # elements
    if quantization_recipe.mxfp8():
        if model_name not in ("linear_op", "layernorm_mlp_ops"):
            # Modules do not deallocate column-wise MXFP8 data for weights
            return 2 * param_elements * (1 + 1 / 32) / 1024**2
        return param_elements * (1 + 1 / 32) / 1024**2

    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")


def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.

    Memory measurement excludes the input and output tensors.

    """

    # Reset memory
    gc.collect()
    torch.cuda.empty_cache()

    # Context and sync function for CPU offloading
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(modules),
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x

    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.fp8_autocast(
            enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
        ), offload_context:
            tensor = module(tensor)
        tensor = sync_function(tensor)
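    # Offloading of a layer's activations is triggered when the next layer
    # starts; the dummy step below (and model_layers=len(modules)+1 above)
    # ensures the last real module's activations are offloaded as well.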
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
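    # Exclude the output tensor itself from the measured memory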
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)

    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()

    # Memory usage in MiB
    return memory_after_forward - memory_before_forward


@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""

    # Construct model
    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
    if model_name in ["multihead_attention", "transformer_layer"]:
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
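        # Disable flash attention (its saved tensors cannot be offloaded) and
        # force the attention backend selection to be re-evaluated.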
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    # Warmup
    _warmup_model(modules_list, quantization_recipe)

    # Measure cached memory after forward pass
    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

    # Check for expected memory usage
    assert memory_with_offload < memory_without_offload
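    # With offloading enabled, the memory retained after the forward pass should
    # be roughly just the cached (quantized) weight copies, within EPSILON MiB.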
    memory_from_cached_weights = _estimate_cached_weight_size(
        model_name,
        modules_list,
        quantization_recipe,
    )
    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON