# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import contextlib
import gc
import os
from typing import Iterable, Optional

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends

# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()

quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
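# Presumably the MXFP8 recipe should also be exercised when supported, since
# mxfp8_available is checked above and _estimate_cached_weight_size handles
# recipe.mxfp8() below.
if mxfp8_available:
    quantization_recipes.append(recipe.MXFP8BlockScaling())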

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps

# Flash attention saves internal tensors for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0", "NVTE_FLASH_ATTN must be set to 0 for this test"

# Activation offloading for attention is supported only with the fused and
# flash attention backends, so bfloat16 parameters are required.
#
# For TransformerLayer, activation offloading is not supported together with
# dropout, so hidden_dropout is set to 0.0.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
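    # Operation-based (te.ops) counterparts of the modules above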
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}


def _make_input() -> torch.Tensor:
    """Generate random input tensor."""
    return torch.randn(
        (128, SIZE, SIZE),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )


def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Perform forward and backward pass"""
    tensor = _make_input()
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None,
            recipe=quantization_recipe,
        ):
            tensor = module(tensor)
    tensor.sum().backward()


def _estimate_cached_weight_size(
    model_name: str,
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> float:
    """Calculate the memory (in MiB) needed for weight caching."""

    # The weight params are cached directly for unquantized compute
    if quantization_recipe is None:
        return 0

    # Count number of weight param elements
    param_elements = 0
    for module in modules:
        for param in module.parameters():
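            # Only 2-D GEMM weights are quantized and cached; biases and other
            # 1-D parameters are skipped.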
            if param.dim() == 2:
                param_elements += param.numel()

    # FP8 tensor-scaling caches one byte per element
    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
        if not is_non_tn_fp8_gemm_supported() and model_name not in (
            "linear_op",
            "layernorm_mlp_ops",
        ):
            # Modules do not deallocate FP8 transpose for weights
            return 2 * param_elements / 1024**2
        return param_elements / 1024**2

    # MXFP8 caches one data byte per element and one scale byte per 32
    # elements
    if quantization_recipe.mxfp8():
        if model_name not in ("linear_op", "layernorm_mlp_ops"):
            # Modules do not deallocate column-wise MXFP8 data for weights
            return 2 * param_elements * (1 + 1 / 32) / 1024**2
        return param_elements * (1 + 1 / 32) / 1024**2

    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")


def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.

    Memory measurement excludes the input and output tensors.

    """

    # Reset memory
    gc.collect()
    torch.cuda.empty_cache()

    # Context and sync function for CPU offloading
    if cpu_offload:
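        # model_layers counts one extra "layer" for the dummy clone step
        # performed after the last module in the forward loop below, so that
        # module's activations are also offloaded and synchronized.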
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(modules),
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x

    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None, recipe=quantization_recipe
        ), offload_context:
            tensor = module(tensor)
        tensor = sync_function(tensor)
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
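    # Exclude the output tensor's bytes; the input was already allocated before
    # the baseline measurement, so only cached activations and weights remain.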
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)

    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()

    # Memory usage in MiB
    return memory_after_forward - memory_before_forward


@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""

    # Construct model
    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
    if model_name in ["multihead_attention", "transformer_layer"]:
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
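        # Disable flash attention and force backend re-selection so the fused
        # attention backend (which supports offloading) is used.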
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    # Warmup
    _warmup_model(modules_list, quantization_recipe)

    # Measure cached memory after forward pass
    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

    # Check for expected memory usage
    assert memory_with_offload < memory_without_offload
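    # With offloading enabled, only the cached (quantized) weight copies are
    # expected to stay resident on the GPU after the forward pass.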
    memory_from_cached_weights = _estimate_cached_weight_size(
        model_name,
        modules_list,
        quantization_recipe,
    )
    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON