test_cpu_offloading_v1.py
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import contextlib
import gc
import os
from typing import Iterable, Optional

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends

# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()

quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
if mxfp8_available:
    quantization_recipes.append(recipe.MXFP8BlockScaling())

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps

# Flash attention saves some internal tensors for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"

# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"

# For attention, offloading is only supported with the fused and flash
# attention backends, so bfloat16 params are required.
#
# For the TransformerLayer, activation offloading is not supported together
# with dropout, so hidden_dropout is set to 0.0.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}


def _make_input() -> torch.Tensor:
    """Generate random input tensor."""
    return torch.randn(
        (128, SIZE, SIZE),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )


def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Perform forward and backward pass"""
    tensor = _make_input()
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None,
            recipe=quantization_recipe,
        ):
            tensor = module(tensor)
    tensor.sum().backward()


def _estimate_cached_weight_size(
    model_name: str,
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> float:
    """Calculate the memory (in MiB) needed for weight caching."""

    # The weight params are cached directly for unquantized compute, so no
    # additional GPU memory is needed for weight caching
    if quantization_recipe is None:
        return 0

    # Count number of weight param elements
    param_elements = 0
    for module in modules:
        for param in module.parameters():
            if param.dim() == 2:
                param_elements += param.numel()

    # FP8 tensor-scaling caches one byte per element
    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
        if not is_non_tn_fp8_gemm_supported() and model_name not in (
            "linear_op",
            "layernorm_mlp_ops",
        ):
            # Modules do not deallocate FP8 transpose for weights
            return 2 * param_elements / 1024**2
        return param_elements / 1024**2
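
    # Example with the "small" config above: the "linear" model caches
    # NUM_LAYERS * SIZE * SIZE = 5 * 512 * 512 = 1,310,720 FP8 bytes,
    # i.e. 1.25 MiB (2.5 MiB when the FP8 transpose is also retained).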

    # MXFP8 caches one data byte per element and one scale byte per 32
    # elements
    if quantization_recipe.mxfp8():
        if model_name not in ("linear_op", "layernorm_mlp_ops"):
            # Modules do not deallocate column-wise MXFP8 data for weights
            return 2 * param_elements * (1 + 1 / 32) / 1024**2
        return param_elements * (1 + 1 / 32) / 1024**2
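
    # Example with the "small" config above: the "linear" model caches
    # 1,310,720 * (1 + 1/32) bytes ~= 1.29 MiB of MXFP8 data and scales
    # (~2.58 MiB when the column-wise copy is also retained).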

    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")


def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.

    Memory measurement excludes the input and output tensors.

    """

    # Reset memory
    gc.collect()
    torch.cuda.empty_cache()

    # Context and sync function for CPU offloading
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(modules),
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x
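
    # The sync function returned by get_cpu_offload_context is applied to each
    # module's output so that the offload copies are coordinated with compute;
    # with offloading disabled it is simply the identity.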

    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None, recipe=quantization_recipe
        ), offload_context:
            tensor = module(tensor)
        tensor = sync_function(tensor)
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)

    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()

    # Memory usage in MiB
    return memory_after_forward - memory_before_forward


@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""

    # Construct model
    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
    if model_name in ["multihead_attention", "transformer_layer"]:
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
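        # Disable flash attention and force backend re-selection so the fused
        # attention backend is used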
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    # Warmup
    _warmup_model(modules_list, quantization_recipe)

    # Measure cached memory after forward pass
    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

    # Check for expected memory usage
    assert memory_with_offload < memory_without_offload
    memory_from_cached_weights = _estimate_cached_weight_size(
        model_name,
        modules_list,
        quantization_recipe,
    )
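    # With offloading enabled, the memory retained after the forward pass should
    # consist only of the cached quantized weights (within EPSILON MiB)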
    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON