Unverified commit c5257605, authored by Paweł Gadziński, committed by GitHub

[PyTorch] Activation offloading refactor (#1762)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* offloading
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* all types
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* api change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* refactor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* example
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cpu offload + debug warning
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change empty_like implementation to use make_like
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* main_grad fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* manual synchronization
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* old path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* remove example
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* api changes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* reverted grouped linear
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* make old code path work for modules
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* attention old code path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* updated code path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nvfp4 support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update tests/pytorch/test_cpu_offloading.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* small fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docs change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@ptyche0312.ptyche.clusters.nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
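The refactored API, as exercised by the tests in the diff below, routes a layer's saved activations through a synchronizer object: `fwd_step()` opens a layer, `push_tensor()` registers a tensor for offload, `bwd_step()` triggers the reload for a layer, and `pop_tensor()` hands the tensor back during backward. The following is a minimal, CPU-only sketch of that bookkeeping pattern; it is illustrative only, not the Transformer Engine implementation (`ToyOffloadSynchronizer` and its behavior are assumptions loosely modeled on the `DefaultOffloadSynchronizer` usage visible in the tests):

```python
class ToyOffloadSynchronizer:
    """Toy model of the offload bookkeeping: the first `num_offloaded_layers`
    layers have their pushed tensors copied "to CPU" and reloaded on demand."""

    def __init__(self, num_layers: int, num_offloaded_layers: int):
        self.num_offloaded = num_offloaded_layers
        self.layer = -1      # current layer index during the forward pass
        self.gpu = {}        # tensor_id -> data kept resident ("on GPU")
        self.cpu = {}        # tensor_id -> data copied out ("to CPU")
        self.owner = {}      # tensor_id -> layer that pushed it
        self.next_id = 0

    def fwd_step(self) -> int:
        # Called at the start of each layer's forward; returns the layer id.
        self.layer += 1
        return self.layer

    def push_tensor(self, data) -> int:
        # Register a tensor saved for backward; offloaded layers get a copy.
        tid = self.next_id
        self.next_id += 1
        self.owner[tid] = self.layer
        if self.layer < self.num_offloaded:
            self.cpu[tid] = list(data)   # simulate the device-to-host copy
        else:
            self.gpu[tid] = data         # kept resident, no offload
        return tid

    def bwd_step(self, layer_id: int) -> None:
        # Called when a layer's backward begins: reload everything it owns.
        for tid, owner in list(self.owner.items()):
            if owner == layer_id and tid in self.cpu:
                self.gpu[tid] = list(self.cpu.pop(tid))  # simulate H2D copy

    def pop_tensor(self, tid: int):
        # Hand the (now resident) tensor back to the backward pass.
        return self.gpu.pop(tid)
```

A round trip mirrors the loop structure of the `test_general` tests: push one tensor per layer during forward, then walk the layers in reverse, reloading and popping each tensor before use.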
parent a0754757
@@ -42,7 +42,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
@@ -2,27 +2,41 @@
#
# See LICENSE for license information.
import random
import contextlib
import gc
import os
from typing import Iterable, Optional
import pytest
import os
import torch
from typing import Optional, List
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context,
OffloadableLayerState,
DefaultOffloadSynchronizer,
start_offload,
mark_not_offload,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
from utils import ModelConfig
import transformer_engine_torch as tex
# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
quantization_recipes: List[Optional[recipe.Recipe]] = [None]
if fp8_available:
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
if fp8_block_scaling_available:
quantization_recipes.append(recipe.Float8BlockScaling())
if mxfp8_available:
quantization_recipes.append(recipe.MXFP8BlockScaling())
if nvfp4_available:
quantization_recipes.append(recipe.NVFP4BlockScaling())
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
@@ -32,181 +46,709 @@ NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensors for the backward pass
# that cannot be offloaded to the CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"
# Disable garbage collection to test whether there are reference cycles.
# We do not want them, because they can result in CUDA out-of-memory errors.
import gc
# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
"linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
"multihead_attention": lambda: te.MultiheadAttention(
SIZE, NUM_HEADS, params_dtype=torch.bfloat16
),
"transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
}
gc.disable()
class Utils:
tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
_B = 64
_S = 256
_H = 4
_D = 256
@staticmethod
def long_job(stream: Optional[torch.cuda.Stream] = None):
NUM_ITERS = 6000
if stream is None:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
for i in range(NUM_ITERS):
Utils.tensor1.normal_()
@staticmethod
def measure_time(func):
import time
torch.cuda.synchronize()
start = time.time()
func()
torch.cuda.synchronize()
end = time.time()
return (end - start) * 1000
@staticmethod
def get_cuda_memory_mb():
return torch.cuda.memory_allocated() / (1024**2)
@staticmethod
def get_max_cuda_memory_mb():
return torch.cuda.max_memory_allocated() / (1024**2)
@staticmethod
def get_cpu_memory_mb() -> float:
import psutil, os
return psutil.Process(os.getpid()).memory_info().rss / (1024**2)
@staticmethod
def get_layer_names():
return [
"linear",
"layernorm_linear",
"layernorm_mlp",
"grouped_linear",
"multihead_attention",
"transformer_layer",
"linear_op",
"layernorm_mlp_ops",
]
@staticmethod
def create_layer(layer_type: str):
if layer_type == "linear":
return te.Linear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "layernorm_linear":
return te.LayerNormLinear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "layernorm_mlp":
return te.LayerNormMLP(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "multihead_attention":
return te.MultiheadAttention(
Utils._D, Utils._H, attention_dropout=0.0, params_dtype=torch.bfloat16
)
elif layer_type == "grouped_linear":
return te.GroupedLinear(Utils._H, Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "transformer_layer":
return te.TransformerLayer(
Utils._D,
Utils._D,
Utils._H,
attention_dropout=0.0,
hidden_dropout=0.0,
params_dtype=torch.bfloat16,
)
elif layer_type == "linear_op":
return te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16)
elif layer_type == "layernorm_mlp_ops":
return te.ops.Sequential(
te.ops.LayerNorm(Utils._D, dtype=torch.bfloat16),
te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
)
else:
raise ValueError(f"Unknown layer type: {layer_type}")
@staticmethod
def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) -> torch.Tensor:
shape = (Utils._B, Utils._S, Utils._D)
tensor = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
if recipe is None:
tensor = tensor.requires_grad_() if requires_grad else tensor
return tensor
elif recipe.delayed():
quantizer = te.tensor.float8_tensor.Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
scale=torch.tensor([1.0], device="cuda"),
amax=torch.tensor([1.0], device="cuda"),
)
return quantizer(tensor)
elif recipe.float8_current_scaling():
quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"
)
return quantizer(tensor)
elif recipe.float8_block_scaling():
quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True
)
return quantizer(tensor)
elif recipe.mxfp8():
quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
return quantizer(tensor)
elif recipe.nvfp4():
quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer()
return quantizer(tensor)
@staticmethod
def create_recipe_ctx(recipe: Optional[recipe.Recipe]):
if recipe is None:
return lambda: contextlib.nullcontext()
else:
return lambda: te.fp8_autocast(fp8_recipe=recipe)
@staticmethod
def get_tensor_size_mb(tensor):
if tensor is None:
return 0
if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage):
return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors())
else:
return tensor.numel() * tensor.element_size() / (1024**2)
@staticmethod
def memory_leak_check():
# Should be called before each test.
# Only cublas workspaces and some global tensors are allowed to be allocated.
# All other allocations should be released.
# This is a simple check to catch memory leaks.
if Utils.get_cuda_memory_mb() > 1000:
memory_num = Utils.get_cuda_memory_mb()
import gc
gc.collect() # We want next test to be run with clean state.
gc.disable()
raise RuntimeError(f"Memory leak: {memory_num} MB")
class TestsOffloadableLayerState:
@pytest.mark.parametrize("random_num_tensors", [True, False])
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_general(self, random_num_tensors, recipe):
"""
Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
for each layer offload random number of random tensors.
Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
"""
Utils.memory_leak_check()
NUM_ITERATIONS = 10
stream = torch.cuda.Stream()
offload_layer_state = OffloadableLayerState(
offload_stream=stream,
)
for _ in range(NUM_ITERATIONS):
original_tensors = []
tensors_ids = []
NUM_TENSORS = random.choice([1, 20]) if random_num_tensors else 1
for _ in range(NUM_TENSORS):
tensor = Utils.create_tensor(recipe)
original_tensors.append(tensor)
tensor_id = offload_layer_state.push_tensor(tensor)
assert tensor.device.type == "cuda"
tensors_ids.append(tensor_id)
offload_layer_state.start_offload()
offload_layer_state.release_activation_forward_gpu_memory()
offload_layer_state.start_reload()
for j in range(len(tensors_ids)):
tensor_gpu = offload_layer_state.pop_tensor(tensors_ids[j])
assert tensor_gpu.device.type == "cuda"
assert tensor_gpu.shape == original_tensors[j].shape
assert tensor_gpu.dtype == original_tensors[j].dtype
torch.testing.assert_close(tensor_gpu, original_tensors[j])
offload_layer_state.release_all_memory()
torch.cuda.synchronize()
def test_offload_base_tensor(self):
Utils.memory_leak_check()
stream = torch.cuda.Stream()
offload_layer_state = OffloadableLayerState(
offload_stream=stream,
)
init_cuda_memory = Utils.get_cuda_memory_mb()
x = Utils.create_tensor(None)
x_size = Utils.get_tensor_size_mb(x)
x_1 = x[::2]
x_2 = x[1::2]
start_offload(x_1, offload_base_tensor=True)
start_offload(x_2, offload_base_tensor=True)
x1_id = offload_layer_state.push_tensor(x_1)
x2_id = offload_layer_state.push_tensor(x_2)
del x_1, x_2
offload_layer_state.start_offload()
offload_layer_state.release_activation_forward_gpu_memory()
assert offload_layer_state.get_offloaded_total_size_mb() == pytest.approx(x_size, 0.1)
offload_layer_state.start_reload()
x_1 = offload_layer_state.pop_tensor(x1_id)
x_2 = offload_layer_state.pop_tensor(x2_id)
assert x_1.device.type == "cuda"
assert x_2.device.type == "cuda"
assert torch.allclose(x_1, x[::2])
assert torch.allclose(x_2, x[1::2])
del x
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1)
class TestsDefaultOffloadSynchronizer:
@pytest.mark.parametrize("random_num_tensors", [True, False])
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_general(self, random_num_tensors, recipe):
"""
Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
for each layer offload random number of random tensors.
Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
"""
Utils.memory_leak_check()
NUM_LAYERS = 10
NUM_ITERATIONS = 10
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=NUM_LAYERS,
num_offloaded_layers=NUM_LAYERS - 1,
)
for _ in range(NUM_ITERATIONS):
original_tensors = []
tensors_ids = []
layer_ids = []
for i in range(NUM_LAYERS):
NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1
layer_tensors = []
layer_tensors_ids = []
layer_id = offload_synchronizer.fwd_step()
for _ in range(NUM_LAYER_TENSORS):
tensor = Utils.create_tensor(recipe)
layer_tensors.append(tensor)
tensor_id = offload_synchronizer.push_tensor(tensor)
assert tensor.device.type == "cuda"
layer_tensors_ids.append(tensor_id)
layer_ids.append(layer_id)
tensors_ids.append(layer_tensors_ids)
original_tensors.append(layer_tensors)
for i in range(NUM_LAYERS - 1, -1, -1):
offload_synchronizer.bwd_step(layer_ids[i])
for j in range(len(tensors_ids[i])):
tensor_gpu = offload_synchronizer.pop_tensor(tensors_ids[i][j])
assert tensor_gpu.device.type == "cuda"
assert tensor_gpu.shape == original_tensors[i][j].shape
assert tensor_gpu.dtype == original_tensors[i][j].dtype
torch.testing.assert_close(tensor_gpu, original_tensors[i][j])
offload_synchronizer.finish_part_of_bwd()
torch.cuda.synchronize()
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_memory(self, recipe):
torch.cuda.synchronize()
Utils.memory_leak_check()
NUM_LAYERS = 10
torch.cuda.reset_peak_memory_stats()
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=NUM_LAYERS,
num_offloaded_layers=NUM_LAYERS - 1,
)
def _make_input() -> torch.Tensor:
"""Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _warmup_model(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> None:
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
init_cuda_memory = Utils.get_cuda_memory_mb()
tensor_ids = []
torch.cuda.synchronize()
for _ in range(NUM_LAYERS):
offload_synchronizer.fwd_step()
tensor = Utils.create_tensor(recipe)
tensor_size = Utils.get_tensor_size_mb(tensor)
tensor_id = offload_synchronizer.push_tensor(tensor)
assert tensor.device.type == "cuda"
tensor_ids.append(tensor_id)
del tensor, tensor_id
torch.cuda.synchronize()
if recipe is None:
assert Utils.get_max_cuda_memory_mb() == pytest.approx(
init_cuda_memory + tensor_size, 0.1
)
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1)
for i in range(NUM_LAYERS - 1, -1, -1):
offload_synchronizer.bwd_step(i)
tensor_gpu = offload_synchronizer.pop_tensor(tensor_ids[i])
assert tensor_gpu.device.type == "cuda"
del tensor_gpu, tensor_ids[i]
offload_synchronizer.finish_part_of_bwd()
del tensor_ids
torch.cuda.synchronize()
if recipe is None:
assert Utils.get_max_cuda_memory_mb() == pytest.approx(
init_cuda_memory + tensor_size, 0.1
)
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_multiple_tensor_offload(self, recipe):
Utils.memory_leak_check()
init_cpu_memory = Utils.get_cpu_memory_mb()
init_cuda_memory = Utils.get_cuda_memory_mb()
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=2,
num_offloaded_layers=1,
)
x1 = Utils.create_tensor(recipe)
x_size = Utils.get_tensor_size_mb(x1)
offload_synchronizer.fwd_step()
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.fwd_step()
# Only one copy of the tensor is allocated on the CPU.
assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
del x1
offload_synchronizer.bwd_step(1)
offload_synchronizer.bwd_step(0)
offload_synchronizer.finish_part_of_bwd()
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
class TestTELayers:
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_sanity(self, layer_type, recipe):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
recipe_ctx = Utils.create_recipe_ctx(recipe)
init_cuda_memory = Utils.get_cuda_memory_mb()
OFFLOAD_LAYERS = 6
NUM_LAYERS = 10
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=OFFLOAD_LAYERS,
model_layers=NUM_LAYERS,
)
layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)]
inp = Utils.create_tensor(None)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
out = inp
for i in range(NUM_LAYERS):
with offload_ctx, recipe_ctx():
# Ops-based layers don't support is_first_microbatch parameter
if layer_type in ["linear_op", "layernorm_mlp_ops"]:
out = layers[i](out, **m_splits)
else:
out = layers[i](out, is_first_microbatch=False, **m_splits)
out = sync_function(out)
out.sum().backward()
torch.cuda.synchronize()
del out, inp, layers
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_memory(self, layer_type, recipe):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=1,
model_layers=2,
offload_activations=True,
offload_weights=False,
)
recipe_ctx = Utils.create_recipe_ctx(recipe)
layer = Utils.create_layer(layer_type)
inp = Utils.create_tensor(None)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
# Ops-based layers don't support is_first_microbatch parameter
is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
with recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=True, **m_splits)
out.sum().backward()
del inp
init_cuda_memory = Utils.get_cuda_memory_mb()
# run layer without offload
inp = Utils.create_tensor(None)
with recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=False, **m_splits)
with recipe_ctx():
out = out + 1
del inp
cuda_memory_no_offload = Utils.get_cuda_memory_mb()
out.sum().backward()
# run layer with offload
inp = Utils.create_tensor(None)
with offload_ctx, recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=False, **m_splits)
out = sync_function(out)
with offload_ctx, recipe_ctx():
out = out + 1
out = sync_function(out)
del inp
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb()
# This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer.
# It helps catch cases where an offloaded tensor still has a live pointer, which would
# cause an unnecessary copy to the CPU and prevent GPU memory from being released.
assert Utils.get_cuda_memory_mb() + offloaded_memory_cpu == pytest.approx(
cuda_memory_no_offload, 0.1
)
out.sum().backward()
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_manual_synchronization(self, recipe, layer_type):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
offload_ctx, sync_function, manual_controller = get_cpu_offload_context(
enabled=True,
model_layers=6,
offload_activations=True,
manual_synchronization=True,
)
layer_1 = Utils.create_layer(layer_type)
layer_2 = Utils.create_layer(layer_type)
inp1 = Utils.create_tensor(None)
inp2 = Utils.create_tensor(None)
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
recipe_ctx = Utils.create_recipe_ctx(recipe)
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
init_cuda_memory = Utils.get_cuda_memory_mb()
# 1 fwd
with offload_ctx, recipe_ctx():
out_1 = layer_1(inp1, **m_splits)
out_1 = sync_function(out_1)
with offload_ctx, recipe_ctx():
out_2 = layer_2(inp2, **m_splits)
out_2 = sync_function(out_2)
mark_not_offload(out_1, out_2)
del inp1, inp2
memory_before_offload = Utils.get_cuda_memory_mb()
manual_controller.start_offload_layer(0)
manual_controller.release_activation_forward_gpu_memory(0)
manual_controller.start_offload_layer(1)
manual_controller.release_activation_forward_gpu_memory(1)
memory_after_offload = Utils.get_cuda_memory_mb()
assert memory_after_offload + EPSILON < memory_before_offload
manual_controller.start_reload_layer(0)
manual_controller.start_reload_layer(1)
memory_after_reload = Utils.get_cuda_memory_mb()
assert memory_after_reload == pytest.approx(memory_before_offload, 0.1)
out_1.sum().backward()
out_2.sum().backward()
@pytest.mark.parametrize("recipe", quantization_recipes)
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
@pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False])
@pytest.mark.parametrize("backend", ["FlashAttention", "FusedAttention", "UnfusedAttention"])
def test_numerics(
self,
recipe,
layer_type,
use_cuda_graphs,
backend,
retain_pinned_cpu_buffers,
):
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
recipe_ctx = Utils.create_recipe_ctx(recipe)
if use_cuda_graphs and not retain_pinned_cpu_buffers:
pytest.skip(
"Cuda graphs are not yet supported with cpu offloading when"
" retain_pinned_cpu_buffers is False."
)
if backend == "FusedAttention" and use_cuda_graphs:
pytest.skip(
"Fused attention + cuda graphs is temporarily broken, not because of cpu offloading"
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=1,
model_layers=2,
offload_activations=True,
offload_weights=False,
retain_pinned_cpu_buffers=retain_pinned_cpu_buffers,
)
class Callable(torch.nn.Module):
def __init__(self, offload_ctx=None, sync_function=None):
super().__init__()
self.layers = torch.nn.ModuleList(
[Utils.create_layer(layer_type) for _ in range(2)]
)
self.offload_ctx = offload_ctx
self.sync_function = sync_function
def forward(self, x):
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
for layer in self.layers:
with self.offload_ctx, recipe_ctx():
if is_ops_layer:
x = layer(x, **m_splits)
else:
x = layer(x, is_first_microbatch=False, **m_splits)
if self.sync_function is not None:
x = self.sync_function(x)
return x
callable_offload = Callable(offload_ctx=offload_ctx, sync_function=sync_function)
callable_no_offload = Callable(offload_ctx=contextlib.nullcontext(), sync_function=None)
# copy parameters
for param_offload, param_no_offload in zip(
callable_offload.parameters(), callable_no_offload.parameters()
):
param_offload.data.copy_(param_no_offload.data)
x = Utils.create_tensor(None)
if use_cuda_graphs:
callable_offload = te.make_graphed_callables(
callable_offload,
(x,),
enabled=recipe is not None,
recipe=(Utils.create_recipe_ctx(recipe) if recipe is not None else None),
)
# warm up (e.g. to compute scaling factors for delayed scaling)
for _ in range(4):
out = callable_offload(x)
out.sum().backward()
out = callable_no_offload(x)
out.sum().backward()
callable_offload.zero_grad(set_to_none=True)
out_offload = callable_offload(x)
out_offload.sum().backward()
# save output and parameters
offload_outs = [out_offload]
for param in callable_offload.parameters():
offload_outs.append(param.detach().clone())
torch.cuda.reset_peak_memory_stats()
out_no_offload = callable_no_offload(x)
out_no_offload.sum().backward()
# collect output and parameters
no_offload_outs = [out_no_offload]
for param in callable_no_offload.parameters():
no_offload_outs.append(param.detach().clone())
# check if tensors are the same
for i in range(len(offload_outs)):
assert torch.allclose(offload_outs[i], no_offload_outs[i]), f"Error in tensor {i}."
torch.cuda.synchronize()
def test_example_from_doc(self):
offload_stream = torch.cuda.Stream()
num_layers = 10
layers = [Utils.create_layer("transformer_layer") for _ in range(num_layers)]
inp = [Utils.create_tensor(None) for _ in range(num_layers)]
out = [None] * num_layers
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True,
model_layers=num_layers,
manual_synchronization=True,
offload_stream=offload_stream,
)
for i in range(num_layers):
with cpu_offload_context:
out[i] = layers[i].forward(inp[i])
out[i] = sync_function(out[i])
manual_controller.start_offload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
manual_controller.release_activation_forward_gpu_memory(i)
for i in range(num_layers - 1, -1, -1):
# these calls are intended to be done in the backward pass
manual_controller.start_reload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
out[i].sum().backward()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import contextlib
import gc
import os
from typing import Iterable, Optional
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensors for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"
# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"
# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
"linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
"multihead_attention": lambda: te.MultiheadAttention(
SIZE, NUM_HEADS, params_dtype=torch.bfloat16
),
"transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
}
def _make_input() -> torch.Tensor:
"""Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _warmup_model(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> None:
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
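The size rules encoded above reduce to simple arithmetic: one byte per element for FP8 tensor scaling, an extra 1/32 byte per element for MXFP8 scale factors, and a factor of two when the transpose / column-wise copy is retained. A standalone sketch of the same arithmetic (the helper name and the `recipe_kind` strings are illustrative, not part of the test suite):

```python
def estimate_cached_weight_mib(param_elements, recipe_kind, keeps_transpose):
    """Estimate MiB cached for quantized weights (sketch of the rules above)."""
    if recipe_kind is None:
        return 0.0  # unquantized compute caches the params directly
    if recipe_kind in ("delayed", "float8_current_scaling"):
        bytes_per_elem = 1.0  # FP8: one data byte per element
    elif recipe_kind == "mxfp8":
        bytes_per_elem = 1.0 + 1.0 / 32  # plus one scale byte per 32 elements
    else:
        raise NotImplementedError(recipe_kind)
    factor = 2 if keeps_transpose else 1  # transpose / column-wise copy retained
    return factor * param_elements * bytes_per_elem / 1024**2

# e.g. a 512x512 weight under MXFP8 with the column-wise copy retained:
# 2 * 512*512 * (1 + 1/32) / 2**20 = 0.515625 MiB
```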
def _measure_cached_memory(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
cpu_offload: bool,
) -> float:
"""Measure the growth in allocated GPU memory in MiB after a model forward pass.
Memory measurement excludes the input and output tensors.
"""
# Reset memory
gc.collect()
torch.cuda.empty_cache()
# Context and sync function for CPU offloading
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=len(modules),
model_layers=len(modules) + 1,
offload_activations=True,
offload_weights=False,
)
else:
offload_context = contextlib.nullcontext()
sync_function = lambda x: x
# Forward pass, with dummy step to trigger offload for last module
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
# Backward pass
tensor.sum().backward()
torch.cuda.synchronize()
# Memory usage in MiB
return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
# Construct model
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
if model_name in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
# Warmup
_warmup_model(modules_list, quantization_recipe)
# Measure cached memory after forward pass
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)
# Check for expected memory usage
assert memory_with_offload < memory_without_offload
memory_from_cached_weights = _estimate_cached_weight_size(
model_name,
modules_list,
quantization_recipe,
)
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
@@ -50,6 +50,13 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
)
from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_activation_offload,
NVTE_CPU_OFFLOAD_V1,
)
from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded
# Import attention utils
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
@@ -737,6 +744,9 @@ class FlashAttention(torch.nn.Module):
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
if is_cpu_offload_enabled():
start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
# get batch_size, max_seqlen and cu_seqlens
batch_size, context_len = None, None
if inference_params is None:
@@ -877,12 +887,7 @@ class FlashAttention(torch.nn.Module):
fp8_output=fp8_output,
)
else:
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
mark_activation_offload,
)
if CPUOffloadEnabled:
if is_cpu_offload_enabled():
mark_activation_offload(
query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
)
@@ -1116,6 +1121,9 @@ class FusedAttnFunc(torch.autograd.Function):
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
if is_cpu_offload_enabled():
start_offload(q, k, v, offload_base_tensor=True)
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
@@ -1293,12 +1301,7 @@
# used when some tensors are base tensors and lose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
mark_activation_offload,
)
if CPUOffloadEnabled:
if is_cpu_offload_enabled() and NVTE_CPU_OFFLOAD_V1:
if ctx.fp8:
tensor_list = fp8_tensors
else:
@@ -1309,6 +1312,7 @@
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors,
*qkvo_tensors,
@@ -1339,27 +1343,26 @@
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadedLayer,
)
# If interleaved tensor is offloaded, reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
if CPUOffloadedLayer and CPUOffloadEnabled:
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
if NVTE_CPU_OFFLOAD_V1:
# If interleaved tensor is offloaded, reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
if is_current_layer_offloaded() and is_cpu_offload_enabled():
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
else:
ctx.qkv_layout = qkv_layout
else:
ctx.qkv_layout = qkv_layout
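The reload-layout loop above expands an interleaved QKV layout string (where a digit marks how many tensors share one interleaved buffer) into its non-interleaved form, since a reloaded tensor is no longer interleaved. The same transformation as a standalone sketch (the function name is ours, not TE API):

```python
def expand_qkv_layout(qkv_layout: str) -> str:
    """Expand an interleaved QKV layout string into its non-interleaved form,
    mirroring the reload-layout loop above."""
    reload_layout = ""
    for split in qkv_layout.split("_"):
        temp_layout = ""
        rep_count = 1
        for s in split:
            if s.isalpha():
                temp_layout += s
            else:
                rep_count = int(s)
        # each interleaved group becomes rep_count separate layouts
        reload_layout += (temp_layout + "_") * rep_count
    return reload_layout[:-1]

# "sbh3d"      (q, k, v interleaved in one tensor) -> "sbhd_sbhd_sbhd"
# "sbhd_sb2hd" (k and v interleaved)               -> "sbhd_sbhd_sbhd"
```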
......
@@ -1494,14 +1494,6 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_output=fp8_output,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
warnings.warn(
"Attention activation Offloading is only implemented"
"with Flash Attention and Fused Attention!"
)
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention:
......
@@ -33,6 +33,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProduc
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
@@ -971,7 +973,8 @@ class MultiheadAttention(torch.nn.Module):
# ===========================
# Core attention computation
# ===========================
if is_cpu_offload_enabled():
start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
context_layer = self.core_attention(
query_layer,
key_layer,
......
@@ -3,698 +3,748 @@
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
from contextlib import nullcontext
from typing import Any, Dict, Optional
from __future__ import annotations
import contextlib
from collections import defaultdict
from dataclasses import dataclass, field
import os
import warnings
from typing import Any, Optional
import torch
from torch.autograd.graph import saved_tensors_hooks
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path
from .quantized_tensor import (
restore_from_saved,
prepare_for_saving,
)
CPUOffloadEnabled = False
CPUOffloadedLayer = False
def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor."""
if TEDebugState.debug_enabled:
raise RuntimeError("CPU offload is not supported in debug mode.")
for tensor in tensors:
if tensor is None:
continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
tensor.activation_offloading = True
else:
data_tensors = tensor.get_data_tensors()
for tensor in data_tensors:
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force-clear the tensor after it is offloaded.
# It is needed because *TensorStorage classes are saved in the ctx
# and they keep references to their data tensors.
tensor.needs_force_clear = True
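The marker set above is an ordinary Python attribute on the tensor object; the offload handler later filters saved tensors on it. A minimal sketch of that consuming side (the filter helper is hypothetical):

```python
import torch

def tensors_to_offload(saved_tensors):
    """Keep only tensors flagged for activation offloading."""
    return [t for t in saved_tensors if getattr(t, "activation_offloading", False)]

a, b = torch.randn(2), torch.randn(2)
a.activation_offloading = True  # what mark_activation_offload sets
selected = tensors_to_offload([a, b])  # only `a` is selected
```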
__all__ = ["get_cpu_offload_context", "mark_not_offload", "start_offload"]
NVTE_CPU_OFFLOAD_V1 = os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"
def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled
OFFLOAD_SYNCHRONIZER = None
class CpuOffloadSavedTensorHook:
"""Context-manager that executes a pair of pack/unpack hooks for saved tensors.
def is_cpu_offload_enabled():
"""Returns True if CPU offload is enabled."""
if NVTE_CPU_OFFLOAD_V1:
return v1_code_path.is_cpu_offload_enabled()
return OFFLOAD_SYNCHRONIZER is not None
In this context, the ``on_save_for_backward`` method will be called every time
a tensor is saved for backward (this includes intermediary results saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation).
The ``on_get_saved_tensors`` method will be called when the backward function
of this op attempts to retrieve the saved tensor from context (this includes
:func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`). It takes
as input the return value of ``on_save_for_backward``, and is meant to return
an identical copy of the tensor saved by ``on_save_for_backward`` in terms of
size, device and element values.
def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor."""
if NVTE_CPU_OFFLOAD_V1:
v1_code_path.mark_activation_offload(*tensors)
Example:
>>> import torch
>>> from typing import Any
>>>
>>> class DummyHook(CpuOffloadSavedTensorHook):
...
... def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
... logging.info("On save", tensor)
... return (tensor,)
...
... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
... logging.info("On get", saved_state)
... tensor, = saved_state
... return tensor
...
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with DummyHook():
... y = a * b
...
On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)
def mark_not_offload(*tensors: torch.Tensor):
"""Marks tensors to prevent them from being offloaded."""
if NVTE_CPU_OFFLOAD_V1:
return
"""
tensors, tensor_obj = prepare_for_saving(*tensors)
def __init__(self) -> None:
self.inside_context = False
for tensor in tensors:
if tensor is not None:
setattr(tensor, "_TE_do_not_offload", True)
def __enter__(self):
global CPUOffloadEnabled
CPUOffloadEnabled = True
restore_from_saved(tensor_obj, tensors)
self.inside_context = True
torch._C._autograd._push_saved_tensors_default_hooks(
self.on_save_for_backward, self.on_get_saved_tensor
)
def __exit__(self, *args: Any):
global CPUOffloadEnabled
CPUOffloadEnabled = False
def start_offload(*tensors: torch.Tensor, offload_base_tensor: bool = False):
"""
Marks the point on the main stream where the tensors are fully computed and ready to be offloaded.
If offload_base_tensor is True and the tensor is a view, the base tensor is offloaded
and reloaded - the stride and storage offset of the view are saved and restored after reload.
It is useful when multiple tensors are views of the same base tensor,
for example in MultiHeadAttention for interleaved q, k, v tensors.
"""
if NVTE_CPU_OFFLOAD_V1:
return
self.inside_context = False
torch._C._autograd._pop_saved_tensors_default_hooks()
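The private `_push`/`_pop` calls above install process-wide pack/unpack hooks; the public equivalent is `torch.autograd.graph.saved_tensors_hooks`, which is the core mechanism this module builds on. A minimal CPU-offload sketch with none of the streams, events, grouping or quantized-tensor handling of the real handler:

```python
import torch

# Pack copies each saved activation to CPU; unpack copies it back when the
# backward pass needs it.
def pack_to_cpu(tensor):
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    device, cpu_copy = packed
    return cpu_copy.to(device)

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x * x).sum()
y.backward()  # gradients match the un-hooked computation: dy/dx = 2 * x
```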
def _mark_tensor_for_offload(t):
if t is None:
return
# Attach an event to mark when the tensor is ready for reload.
t.start_reload_event = torch.cuda.Event()
t.start_reload_event.record(torch.cuda.current_stream())
if offload_base_tensor and t._base is not None:
setattr(t, "offload_base_tensor", True)
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
"""On save for backward."""
raise NotImplementedError(
"`on_save_for_backward: Callable[[torch.Tensor], Any]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
tensors, tensor_obj = prepare_for_saving(*tensors)
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
"""On get saved tensor."""
raise NotImplementedError(
"`on_get_saved_tensors: Callable[[Any], torch.Tensor]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
for tensor in tensors:
_mark_tensor_for_offload(tensor)
restore_from_saved(tensor_obj, tensors)
class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
"""Context-manager that offloads/recovers tensors through an offload handler.
The hook just pushes/pops the tensor object to/from the handler through the
`tensor_push` and `tensor_pop` interface. How the offload handler manages the
offloading, recovering or prefetching timing is transparent to this hook.
@dataclass
class TensorGroup:
"""
TensorGroup is a collection of tensors, events and auxiliary data.
It is used multiple times in the CPU offload code.
"""
def __init__(
self,
offload_handler: OffloadHandler,
handler_extra_kwargs: Optional[Dict[str, Any]] = None,
debug: bool = False,
) -> None:
if handler_extra_kwargs is None:
handler_extra_kwargs = {}
self.debug: bool = debug
self.offload_handler: OffloadHandler = offload_handler
self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)
return tensor
tensor_list: list[torch.Tensor] = field(default_factory=list)
events: list[torch.cuda.Event] = field(default_factory=list)
aux: Any = None
class OffloadHandler:
"""A base class for CPU offload-handler."""
def __init__(self) -> None:
pass
class TensorGroupProcessor:
"""
Suppose there is a tensor group T that needs to be offloaded.
Sometimes we can transform T into (T_opt, aux), where T_opt is smaller and cheaper
to offload; we then offload T_opt, reload it, and restore T from (T_opt_reloaded, aux).
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
"""Tensor push."""
raise NotImplementedError(
"`tensor_push` is not implemented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_push."
)
This class contains static methods that perform these optimizations - for example
deduplication of tensors and restoring duplicates after reload.
"""
def tensor_pop(self, tensor_tag: Any, **kwargs):
"""Tensor pop."""
raise NotImplementedError(
"`tensor_pop` is not implemented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_pop."
)
@staticmethod
def tensor_group_process_before_offload(tensor_group: TensorGroup) -> tuple[TensorGroup, Any]:
"""
Called for a tensor group just before the offloading logic runs.
aux is a dictionary of auxiliary data needed to restore the pre-offload state.
"""
aux = {}
tensor_group = TensorGroupProcessor._switch_to_base_tensors(aux, tensor_group)
tensor_group = TensorGroupProcessor._deduplicate_tensors(aux, tensor_group)
return tensor_group, aux
class GroupCommitFunction(torch.autograd.Function):
"""A dummy op whose output is identical to its input.
It is nevertheless necessary: it marks a timepoint at which the offload
handler can complete all synchronizations. Implementing it as an autograd
function is required because we need to take actions in both forward and backward.
"""
@staticmethod
def tensor_group_process_after_reload(tensor_group: TensorGroup):
"""
Called for a tensor group just after the reload logic runs.
"""
assert tensor_group.aux is not None
tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group)
tensor_group = TensorGroupProcessor._switch_to_views(tensor_group)
return tensor_group
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
return tensor
def _switch_to_base_tensors(aux, tensor_group: TensorGroup) -> TensorGroup:
"""
Changes tensors to their base tensors and saves view metadata in aux.
If we save multiple tensors that are in fact views of the same base tensor,
this offloads only that one base tensor. It is used, for example, in
MultiheadAttention for interleaved q, k, v tensors.
"""
def _check_if_offload_base_tensor(tensor: torch.Tensor) -> bool:
if getattr(tensor, "offload_base_tensor", False):
return True
if tensor._base is not None:
# If tensor is a view of a tensor and has the same elements,
# but with different strides, we can safely offload the base tensor.
# If tensor is a view on some part of a bigger tensor,
# the decision to offload the base tensor is non-trivial and we do not do it by default.
return tensor._base.numel() == tensor.numel()
return False
aux["views"] = []
for tensor_id in range( # pylint: disable=consider-using-enumerate
len(tensor_group.tensor_list)
):
tensor = tensor_group.tensor_list[tensor_id]
if _check_if_offload_base_tensor(tensor):
aux["views"].append((tensor.shape, tensor.stride(), tensor.storage_offset()))
tensor = tensor._base
assert (
tensor is not None
), "Cannot offload base tensor, if the tensor is not a view."
tensor_group.tensor_list[tensor_id] = tensor
else:
aux["views"].append(None)
return tensor_group
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
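`GroupCommitFunction` above is the identity in both directions; its only purpose is to run handler callbacks at matching points of the forward and backward passes. A self-contained sketch of that pattern (the names and the counting handler are illustrative):

```python
import torch

class CommitMarker(torch.autograd.Function):
    """Identity op that fires a callback in forward and again in backward."""

    @staticmethod
    def forward(ctx, tensor, handler):
        handler["fwd"] += 1  # e.g. on_group_commit_forward()
        ctx.handler = handler
        return tensor  # output identical to input

    @staticmethod
    def backward(ctx, grad_output):
        ctx.handler["bwd"] += 1  # e.g. on_group_commit_backward()
        return grad_output, None  # no gradient for the handler argument

handler = {"fwd": 0, "bwd": 0}
x = torch.ones(3, requires_grad=True)
y = CommitMarker.apply(x * 2, handler)
y.sum().backward()  # both callbacks have now fired exactly once
```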
def _deduplicate_tensors(aux, tensor_group: TensorGroup) -> TensorGroup:
"""
Deduplicate tensors.
"""
dedup_tensors: list[torch.Tensor] = []
dedup_events: list[torch.cuda.Event] = []
tensor_to_index: dict[int, int] = {}
aux["original_tensor_ids"] = []
# If there are several duplicates of the same tensor with different events,
# we keep only the first event: every event is recorded when the tensor is ready
# to be offloaded, so the first event is the optimal one to use.
for tensor_id, tensor in enumerate(tensor_group.tensor_list):
if id(tensor) in tensor_to_index:
aux["original_tensor_ids"].append(tensor_to_index[id(tensor)])
else:
tensor_to_index[id(tensor)] = len(dedup_tensors)
dedup_tensors.append(tensor)
dedup_events.append(tensor_group.events[tensor_id])
aux["original_tensor_ids"].append(tensor_to_index[id(tensor)])
group_prefetch_offload_commit = GroupCommitFunction.apply
tensor_group.tensor_list = dedup_tensors
tensor_group.events = dedup_events
return tensor_group
@staticmethod
def _restore_tensor_duplicates(tensor_group: TensorGroup) -> TensorGroup:
"""
Restore tensor duplicates.
"""
new_tensor_list = []
new_events_list = []
for tensor_id in range(len(tensor_group.aux["original_tensor_ids"])):
original_tensor_id = tensor_group.aux["original_tensor_ids"][tensor_id]
new_tensor_list.append(tensor_group.tensor_list[original_tensor_id])
new_events_list.append(tensor_group.events[original_tensor_id])
tensor_group.tensor_list = new_tensor_list
tensor_group.events = new_events_list
return tensor_group
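The index bookkeeping in `_deduplicate_tensors` / `_restore_tensor_duplicates` can be sketched in isolation: duplicates of the same Python object are stored once and the original slots are rebuilt from the saved indices (helper names are ours):

```python
def deduplicate(items):
    """Store each distinct object once; remember which slot each input maps to."""
    dedup, index_of, original_ids = [], {}, []
    for item in items:
        if id(item) not in index_of:
            index_of[id(item)] = len(dedup)
            dedup.append(item)
        original_ids.append(index_of[id(item)])
    return dedup, original_ids

def restore(dedup, original_ids):
    """Rebuild the original list, duplicates included, from the index mapping."""
    return [dedup[i] for i in original_ids]

x, y = object(), object()
dedup, ids = deduplicate([x, y, x, x])  # only x and y are kept
```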
class SynchronizedGroupOffloadHandler(OffloadHandler):
"""Offload Handler that offloads/reloads in a synchronized way.
The device-to-host and host-to-device copying happen in the same stream
as the computation kernels, thus the copying will block computation.
"""
@staticmethod
def _switch_to_views(tensor_group: TensorGroup) -> TensorGroup:
"""
Switch to views - reverse of _switch_to_base_tensors.
"""
for tensor_id, tensor in enumerate(tensor_group.tensor_list):
if tensor_group.aux["views"][tensor_id] is not None:
tensor_group.tensor_list[tensor_id] = tensor.as_strided(
*tensor_group.aux["views"][tensor_id]
)
return tensor_group
class OffloadableLayerState:
"""
Class that manages offloading and reloading of tensors for a single layer.
"""
def __init__(
self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
) -> None:
super().__init__()
self.num_offload_group = num_offload_group
self.tensor_need_offloading_checker = tensor_need_offloading_checker
self.debug = debug
self.groupid_reset()
def groupid_reset(self):
"""Groupid reset."""
# Data structures to label saved tensors and bookkeep their CPU copies.
# Currently, on push, we create a new CPU tensor and copy into it; on pop,
# we copy the tensor back to GPU and delete the CPU copy.
# These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {}
def on_group_commit_forward(self):
"""On group commit forward."""
# finishing up with updating current group and tensor count
self.current_group += 1 # increment
self.tensor_count_current_group = 0 # reset
def on_group_commit_backward(self):
"""On group commit backward."""
self.current_group -= 1
assert self.current_group >= 0
@staticmethod
def offload(src_tensor, pin_memory=True):
"""Offload."""
cpu_backup = torch.empty(
src_tensor.size(),
dtype=src_tensor.dtype,
layout=src_tensor.layout,
device="cpu",
pin_memory=pin_memory,
)
self,
offload_stream: torch.cuda.Stream,
retain_pinned_cpu_buffers: bool = False,
):
self.offload_stream = offload_stream
self.retain_pinned_cpu_buffers = retain_pinned_cpu_buffers
# There are 3 tensor groups: tensors on gpu before offload,
# tensors on cpu after offload, tensors on gpu after reload.
self.fwd_gpu_tensor_group = TensorGroup()
self.cpu_tensor_group = TensorGroup()
self.bwd_gpu_tensor_group = TensorGroup()
self.aux: dict[str, Any] = {}
# State can be one of: not_offloaded, offload_started,
# offload_finished, reload_started.
self.state = "not_offloaded"
def _validate_state(self, func_name: str, allowed_states: list[str]):
assert (
self.state in allowed_states
), f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}"
def start_offload(self):
"""
Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
This event is recorded in the start_offload or push_tensor call.
"""
self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
self.state = "offload_started"
self.fwd_gpu_tensor_group, aux = TensorGroupProcessor.tensor_group_process_before_offload(
self.fwd_gpu_tensor_group
)
cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@staticmethod
def reload(state, non_blocking=None, copy_buffer=None):
"""Reload."""
dev, cpu_backup = state
if non_blocking is None:
non_blocking = cpu_backup.is_pinned()
if copy_buffer is None:
return cpu_backup.to(dev, non_blocking=non_blocking)
assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!"
copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
allocate_cpu_buffers = (
not self.retain_pinned_cpu_buffers or len(self.cpu_tensor_group.tensor_list) == 0
)
return copy_buffer
for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list):
assert tensor.is_contiguous()
def tensor_push(self, tensor: torch.Tensor, **kwargs):
"""Tensor push."""
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
tensor
):
state = SynchronizedGroupOffloadHandler.offload(tensor)
self.tensor_tag_to_state[tensor_tag] = state
else:
# will be offloaded together after group commit
self.tensor_tag_to_state[tensor_tag] = tensor
# Wait for the moment the tensor is ready to be offloaded.
self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type]
return tensor_tag
with torch.cuda.stream(self.offload_stream):
if allocate_cpu_buffers:
# empty_like is defined also for QuantizedTensors
offloaded_tensor = torch.empty_like(
tensor, device=torch.device("cpu"), pin_memory=True
)
self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
else:
assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
"CPU buffer shape does not match the offloaded tensor shape:"
f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
" Make sure that tensor shapes do not change between"
" iterations if retain_pinned_cpu_buffers is True."
)
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
offloaded_tensor.copy_(tensor, non_blocking=True)
# aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
# needed to restore pre-offload state after reload.
self.aux = aux
self.finish_offload_event = torch.cuda.Event()
self.finish_offload_event.record(self.offload_stream)
def release_activation_forward_gpu_memory(self):
"""
Release GPU memory of the activations.
Waits for offload to finish - memory needs to be kept alive when GPU->CPU copy is performed.
"""
self._validate_state(
func_name="release_activation_forward_gpu_memory", allowed_states=["offload_started"]
)
self.state = "offload_finished"
torch.cuda.current_stream().wait_event(self.finish_offload_event) # type: ignore[arg-type]
# GPU memory can be released safely after the offload.
# Notice that the memory needs to be kept alive when GPU->CPU copy is performed.
self.fwd_gpu_tensor_group = TensorGroup()
del self.finish_offload_event
def start_reload(self):
"""
Start reloading of tensors.
It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.
"""
self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
self.state = "reload_started"
self.bwd_gpu_tensor_group = TensorGroup()
for tensor in self.cpu_tensor_group.tensor_list:
# Notice that reloaded tensor is allocated on main stream,
# not offloaded stream. It is because PyTorch memory allocator
# cannot move tensors from pool of one stream to another without
# calling cudaFree and cudaMalloc again.
# empty_like is defined also for QuantizedTensors.
reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
self.offload_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.offload_stream):
reloaded_tensor.copy_(tensor, non_blocking=True)
reload_tensor_event = torch.cuda.Event()
reload_tensor_event.record(self.offload_stream)
self.bwd_gpu_tensor_group.events.append(reload_tensor_event)
self.bwd_gpu_tensor_group.tensor_list.append(reloaded_tensor)
self.bwd_gpu_tensor_group.aux = self.aux
self.bwd_gpu_tensor_group = TensorGroupProcessor.tensor_group_process_after_reload(
self.bwd_gpu_tensor_group
)
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
state = self.tensor_tag_to_state.pop(tensor_tag)
if isinstance(state, tuple):
tensor = SynchronizedGroupOffloadHandler.reload(state)
else:
tensor = state
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
"""
Called when a tensor is saved for the backward pass.
If the tensor is offloaded, returns an int index of the tensor in the offloaded tensor group.
If the tensor is not offloaded, returns the tensor itself.
"""
self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])
if self._check_if_offload(tensor):
self.fwd_gpu_tensor_group.tensor_list.append(tensor)
# The group is processed and offloaded at the end of the forward pass of current layer.
# To start offloading tensors sooner, we use self.offload_stream and record
# events marking when each tensor becomes ready to be offloaded.
# This way we do not need to wait until the end of the current layer to start offloading.
if hasattr(tensor, "start_reload_event"):
self.fwd_gpu_tensor_group.events.append(tensor.start_reload_event)
else:
self.fwd_gpu_tensor_group.events.append(torch.cuda.Event())
self.fwd_gpu_tensor_group.events[-1].record(torch.cuda.current_stream())
return len(self.fwd_gpu_tensor_group.tensor_list) - 1
return tensor
def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
"""
Called when a tensor is used in the backward pass.
Returns the tensor. If the tensor was offloaded/reloaded, waits for the reload to finish.
"""
self._validate_state(
func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
)
class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Compared to the synchronized handler, this uses more memory because of the
buffers but achieves better performance due to overlapping. D2H and H2D copies
are completely hidden behind computation if a layer's computation time is longer
than the host-device transfer time. Bulk offloading with delay and bulk reloading
with prefetch are implemented."""
# 1. tensor not offloaded
if isinstance(tensor_or_tensor_id, torch.Tensor):
return tensor_or_tensor_id
# 2. the layer was not offloaded at all
if self.state == "not_offloaded":
return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
# 3. the layer was offloaded
assert self.state == "reload_started"
# wait for the tensor to be reloaded
torch.cuda.current_stream().wait_event(
self.bwd_gpu_tensor_group.events[tensor_or_tensor_id]
)
return self.bwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
def release_all_memory(self):
"""Release all GPU and CPU memory held by this state. Called after the backward pass."""
self.fwd_gpu_tensor_group = TensorGroup()
if not self.retain_pinned_cpu_buffers:
self.cpu_tensor_group = TensorGroup()
self.bwd_gpu_tensor_group = TensorGroup()
self.state = "not_offloaded"
def _check_if_offload(self, t: torch.Tensor) -> bool:
"""
Check if tensor needs to be offloaded.
"""
if (
not isinstance(t, torch.nn.Parameter)
and not getattr(t, "_TE_do_not_offload", False)
and not isinstance(t, torch._subclasses.FakeTensor)
and t.device.type == "cuda"
):
if not t.is_contiguous() and not getattr(t, "offload_base_tensor", False):
warnings.warn(
"Tried to offload non-contiguous tensor, which is not supported. Offload of"
" this tensor will be skipped."
)
return False
return True
return False
def get_offloaded_total_size_mb(self) -> float:
"""
Get total size of offloaded tensors in MB, used only for testing.
"""
def get_tensor_size_mb(tensor):
if tensor is None:
return 0
if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage):
return sum(get_tensor_size_mb(t) for t in tensor.get_data_tensors())
return tensor.numel() * tensor.element_size() / (1024**2)
total_size = 0
for tensor in self.cpu_tensor_group.tensor_list:
total_size += get_tensor_size_mb(tensor)
return total_size
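The MB accounting above reduces to `numel * element_size` for each plain tensor; a tiny sketch of that arithmetic (QuantizedTensorStorage recursion omitted, helper name hypothetical):

```python
def tensor_size_mb(numel: int, element_size: int) -> float:
    """Size in MiB of a tensor with `numel` elements of `element_size` bytes each."""
    return numel * element_size / (1024 ** 2)

# A 1024x1024 float32 activation is exactly 4 MiB.
assert tensor_size_mb(1024 * 1024, 4) == 4.0
```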
class OffloadSynchronizer:
"""
Base class responsible for synchronizing offloading and reloading of tensors for multiple layers.
In the base class we only track the layer number and
create OffloadableLayerState instances for all layers, but do not start offloading or reloading.
"""
def __init__(
self,
num_offload_group, # must be <= actual number of groups (number of commits)
num_model_group,
tensor_need_offloading_checker=(lambda t: True),
double_buffering=False,
debug=False,
) -> None:
super().__init__(
num_offload_group=num_offload_group,
tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug,
)
# Number of layers in the model
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
self.layer_window_map = {}
# Data structures for double-buffered reloading
self.double_buffering = double_buffering
self.reload_double_buffer = [[], []]
self.double_buffer_created = False
# Logic to load-balance offloading across computation
# for optimal CPU/GPU interconnect usage
constant = 0
for i in range(self.num_offload_group):
self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
if i < (self.num_layers % self.num_offload_group):
self.layer_window_map[i] += i + 1
constant = i + 1
else:
self.layer_window_map[i] += constant
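The loop above can be checked in isolation; a standalone re-derivation of the same arithmetic (hypothetical helper name), showing where the offload-group commit points land for a sample configuration:

```python
def build_layer_window_map(num_layers: int, num_offload_group: int) -> dict:
    """Same arithmetic as the loop above: spread offload commit points evenly,
    pushing the remainder layers onto the earliest groups."""
    layer_window_map, constant = {}, 0
    for i in range(num_offload_group):
        layer_window_map[i] = ((num_layers // num_offload_group) * (i + 1)) - 1
        if i < (num_layers % num_offload_group):
            layer_window_map[i] += i + 1
            constant = i + 1
        else:
            layer_window_map[i] += constant
    return layer_window_map

# 10 layers over 3 offload groups: commits land after layers 3, 6 and 9.
assert build_layer_window_map(10, 3) == {0: 3, 1: 6, 2: 9}
```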
# allocate streams and events for synchronization
self.d2h_stream = torch.cuda.Stream()
self.h2d_stream = torch.cuda.Stream()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
num_layers: int,
retain_pinned_cpu_buffers: bool = False,
offload_stream: Optional[torch.cuda.Stream] = None,
):
self.num_layers = num_layers
self.offload_stream = offload_stream if offload_stream is not None else torch.cuda.Stream()
self.layer_states = {
i: OffloadableLayerState(self.offload_stream, retain_pinned_cpu_buffers)
for i in range(num_layers)
}
self.num_of_fwds = None
self.previous_bwd_layer_id = None
self.current_layer_id = None
def fwd_step(self) -> int:
"""
Invoked before each layer forward.
"""
if self.num_of_fwds in [None, self.num_layers - 1]:
# reset the offload synchronizer
self.num_of_fwds = 0
else:
self.num_of_fwds += 1
self.current_layer_id = self.num_of_fwds
return self.current_layer_id
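The counter logic in `fwd_step` wraps after the last layer, so repeated iterations reuse layer ids 0..num_layers-1. A small replay of that logic (hypothetical helper name):

```python
def fwd_steps(num_layers: int, calls: int) -> list:
    """Replays the fwd_step counter above for `calls` consecutive invocations."""
    num_of_fwds, seen = None, []
    for _ in range(calls):
        if num_of_fwds in (None, num_layers - 1):
            # reset at the first call and after the last layer of each iteration
            num_of_fwds = 0
        else:
            num_of_fwds += 1
        seen.append(num_of_fwds)
    return seen

# Two-and-a-bit iterations over a 3-layer model: ids cycle 0, 1, 2 and wrap.
assert fwd_steps(3, 7) == [0, 1, 2, 0, 1, 2, 0]
```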
def bwd_step(self, layer_num: int):
"""
Invoked before each layer backward.
"""
if self.previous_bwd_layer_id is not None:
self.layer_states[self.previous_bwd_layer_id].release_all_memory()
self.previous_bwd_layer_id = layer_num
self.current_layer_id = layer_num
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
"""Default push tensor method"""
return self.layer_states[self.num_of_fwds].push_tensor(tensor)
def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
"""Default pop tensor method"""
return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)
def finish_part_of_bwd(self):
"""
Releases the memory held for the backward pass.
It needs to be invoked after every backward pass - there may be
more than one with pipeline parallelism.
This is needed because bwd_step() is invoked before each layer's backward,
while the memory can only be released after the backward pass has finished.
"""
if self.previous_bwd_layer_id is not None:
self.layer_states[self.previous_bwd_layer_id].release_all_memory()
self.previous_bwd_layer_id = None
def get_offloaded_total_size_mb(self) -> float:
"""
Get total size of offloaded tensors in MB, used only for testing.
"""
return sum(
self.layer_states[layer_id].get_offloaded_total_size_mb()
for layer_id in self.layer_states
)
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if is_quantized_tensor:
tensor_list, _ = tensor.prepare_for_saving()
self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = []
# Added support for de-duplicating FP8 param tensors
for _, value in self.fp8_tensor_object_map.items():
if tensor is value:
self.dereferencing_list.append(tensor_tag)
break
class DefaultOffloadSynchronizer(OffloadSynchronizer):
"""
Default implementation of OffloadSynchronizer,
intended to be used in standard training workloads - with multiple forwards
and multiple backwards.
"""
self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr(
tensor, "_transpose_invalid"
)
def __init__(
self,
num_layers: int,
num_offloaded_layers: int | None = None,
retain_pinned_cpu_buffers: bool = False,
offload_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__(num_layers, retain_pinned_cpu_buffers, offload_stream)
# map of layers to bool meaning if layer needs to be offloaded
self.offload_layer_map: dict[int, bool] = {}
# num_layer: int -> list of layers that need to finish offload by this moment
self.finish_offload_map: defaultdict[int, list[int]] = defaultdict(list)
# num_layer: int -> list of layers that need to start reload in this moment
self.start_reload_map: defaultdict[int, list[int]] = defaultdict(list)
self._init_offload_synchronization_dicts(num_offloaded_layers)
def _init_offload_synchronization_dicts(self, num_offloaded_layers: int):
"""
If synchronization dictionary is not provided, the number of offloaded layers is used to initialize
offload_layer_map, finish_offload_map and start_reload_map.
The aim is to minimize memory usage by the end of the forward pass.
The optimal strategy for that is to offload layers 0, ..., num_offloaded_layers - 1.
For layer i offload needs to finish before num_layers - num_offloaded_layers + i.
For layer i reload needs to start after num_layers - num_offloaded_layers + i.
This ensures that - if all layers have memory footprint of T - then peak memory usage of saving activations is
(num_layers - num_offloaded_layers) * T.
"""
for layer_id in range(self.num_layers):
if layer_id < num_offloaded_layers:
self.offload_layer_map[layer_id] = True
self.finish_offload_map[self.num_layers - num_offloaded_layers + layer_id].append(
layer_id
)
self.start_reload_map[self.num_layers - 1 - num_offloaded_layers + layer_id].append(
layer_id
)
else:
tensor_list = [tensor]
for t in tensor_list:
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(t)
else:
self.tensor_tag_to_state[tensor_tag] = t
if (
self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(t)
):
if is_quantized_tensor:
self.tensor_tag_to_buf[tensor_tag].append(t)
# Need to clear the internal data reference for the quantized tensors
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
# Needed to differentiate non offloaded layer's attention
# QKV layout of attention of non-offloaded layer needs
# to be modified while reloading
CPUOffloadedLayer = True
else:
tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1
self.tensor_tag_to_state[tensor_tag] = tensor
self.offload_layer_map[layer_id] = False
return tensor_tag
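The scheduling described in the docstring of `_init_offload_synchronization_dicts` can be re-derived standalone (hypothetical helper name, same index arithmetic as the loop in that method):

```python
from collections import defaultdict

def build_schedules(num_layers: int, num_offloaded_layers: int):
    """Re-derives the three maps: which layers offload, when each offload
    must finish, and when each reload starts (keyed by backward layer id)."""
    offload_layer_map = {}
    finish_offload_map = defaultdict(list)
    start_reload_map = defaultdict(list)
    for layer_id in range(num_layers):
        if layer_id < num_offloaded_layers:
            offload_layer_map[layer_id] = True
            finish_offload_map[num_layers - num_offloaded_layers + layer_id].append(layer_id)
            start_reload_map[num_layers - 1 - num_offloaded_layers + layer_id].append(layer_id)
        else:
            offload_layer_map[layer_id] = False
    return offload_layer_map, dict(finish_offload_map), dict(start_reload_map)

off, fin, rel = build_schedules(4, 2)
assert off == {0: True, 1: True, 2: False, 3: False}
assert fin == {2: [0], 3: [1]}  # layer 0's offload must finish before layer 2's forward
assert rel == {1: [0], 2: [1]}  # layer 0's reload starts when backward reaches layer 1
```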
def fwd_step(self) -> int:
"""
Invoked before each layer forward.
"""
super().fwd_step()
if self.offload_layer_map.get(self.current_layer_id - 1, False):
self.layer_states[self.current_layer_id - 1].start_offload()
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
global CPUOffloadedLayer
for layer in self.finish_offload_map[self.current_layer_id]:
self.layer_states[layer].release_activation_forward_gpu_memory()
return self.current_layer_id
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
def bwd_step(self, layer_num: int):
"""
Invoked before each layer backward.
"""
super().bwd_step(layer_num)
# Handling the quantized tensor case specially here
if isinstance(tensor, list):
# If it's a duplicated tensor, we don't need to locally
# write back a tensor as it would already be written
if tensor_tag in self.dereferencing_list:
self.dereferencing_list.remove(tensor_tag)
else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag)
for layer in self.start_reload_map[layer_num]:
self.layer_states[layer].start_reload()
if self.double_buffering:
tensor._do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group.
assert not isinstance(tensor, tuple)
return tensor
class ManualOffloadSynchronizer(OffloadSynchronizer):
"""
Manual implementation of OffloadSynchronizer,
all synchronization is done manually by the user by using
one of the following methods:
- start_offload_layer
- release_activation_forward_gpu_memory
- start_reload_layer
This implementation is intended for more complex training workflows.
It is useful for example in pipeline parallelism.
"""
def bulk_offload_group(self, group_to_offload):
"""Bulk offload group."""
with torch.cuda.stream(self.d2h_stream):
for tensor_tag, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_tag
if group_id == group_to_offload:
assert not isinstance(state, tuple)
is_quantized_tensor = isinstance(state, list)
if is_quantized_tensor:
tensor_list = state
self.tensor_tag_to_state[tensor_tag] = []
else:
tensor_list = [state]
for tensor_on_device in tensor_list:
# `tensor_offloaded` is a hacky way of dealing with columnwise-only
# quantized tensors for CPU offloading. The complication is due to
# the `rowwise_data` being `None`. The offloading checker incorrectly
# returns `False` and the entire `state` ([None, columnwise_tensor])
# is added to the tensor tag state dict. A better design would change
# how quantized tensors are kept track of in the offload handler.
# Currently at every stage it is ensured that a quantized tensor is a
# list whereas a non-quantized tensor is standalone object, which is
# not good! TODO(@sanandaraj5597)
tensor_offloaded = False
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
tensor_offloaded = True
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
if tensor_offloaded:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have
# the first compute completion
if current_group == 0:
self.d2h_stream.wait_stream(torch.cuda.current_stream())
if not self.double_buffer_created:
# Creating the first copy of double buffer for tensors that are offloaded
for tensor_tag, buf in self.tensor_tag_to_buf.items():
if isinstance(buf, list):
for b in buf:
self.reload_double_buffer[0].append(
torch.empty_like(b) if self.double_buffering else None
)
else:
self.reload_double_buffer[0].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.bulk_offload_group(current_group)
# Window map data structure helps us synchronize based on number
# of layers offloaded
if self.layer_window_map[self.offloaded_group_count] == current_group:
# Stream synchronization both ways
self.d2h_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group
if self.offloaded_group_count < (self.num_offload_group - 1):
self.bulk_offload_group(self.offloaded_group_count + 1)
# Increment the offload group count to keep track
self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
for buf in self.reload_double_buffer[0]:
self.reload_double_buffer[1].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.double_buffer_created = True
def on_group_commit_forward(self):
"""This function will cause host device synchronization"""
# handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group)
super().on_group_commit_forward()
def bulk_reload_group(self, group_to_reload):
"""Bulk reload group."""
assert group_to_reload < self.num_offload_group
buffer_idx = 0
double_buffer_idx = group_to_reload % 2
main_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, reload_buffer
)
buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
True,
reload_buffer,
)
)
buffer_idx = buffer_idx + 1
else:
tensor_list.append(state_tuple)
# No need to write back the duplicated tensor again
# to the same location, this check ensures that
if tensor_label in self.dereferencing_list:
self.dereferencing_list.remove(tensor_label)
else:
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
tensor_list
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label)
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label
)
def on_group_commit_backward(self):
# first decrement the current group.
# after last commit in forward, the group will +1; in backward it -1.
# Finally it should be decremented to 0.
self.current_group -= 1
assert self.current_group >= 0
# Layer window data structure helps us to reload at right times
if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
# Stream synchronization both ways
self.h2d_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.h2d_stream)
# Time to reload the next group
self.bulk_reload_group(self.offloaded_group_count - 1)
# Decrease the offloading group counter
self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0
# Last group computation needs to wait till all the reloads complete
if self.current_group == 0:
torch.cuda.current_stream().wait_stream(self.h2d_stream)
self.offloaded_group_count = 0
def start_offload_layer(self, layer_id: int):
"""
Start offloading of the layer.
Each tensor GPU->CPU copy is done asynchronously on the offload stream.
Each copy begins after tensor_push() has been called for that tensor on the current stream.
"""
self.layer_states[layer_id].start_offload()
def release_activation_forward_gpu_memory(self, layer_id: int):
"""
Release memory of the activations of the layer.
It waits for the offload of the layer to finish.
"""
self.layer_states[layer_id].release_activation_forward_gpu_memory()
def start_reload_layer(self, layer_id: int):
"""
Start reloading of the layer.
Each tensor reload is awaited to finish before tensor_pop() for that tensor is called on the current stream.
"""
self.layer_states[layer_id].start_reload()
def get_cpu_offload_context(
enabled: bool = False,
num_layers: Optional[int] = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
double_buffering: bool = False, # pylint: disable=unused-argument
manual_synchronization: bool = False,
retain_pinned_cpu_buffers: bool = False,
offload_stream: Optional[torch.cuda.Stream] = None,
):
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
CPU offloading feature for sequences of layers. Can be used with arbitrary layers, not necessarily
those provided by TE.
Usage:
.. code-block:: python
cpu_offload_context, sync_function = get_cpu_offload_context(...)
for i in range(num_layers):
with cpu_offload_context:
x = layers[i].forward(x)
x = sync_function(x)
Parameters
----------
enabled: bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1
Determines the number of layers
you want to offload activations/weights for.
model_layers: int, default = 1
Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
Deprecated.
offload_weights: bool, default = `False`
When set to `True`, offloads the weights for the TE layer.
Deprecated.
double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading.
Deprecated.
retain_pinned_cpu_buffers: bool, default = `False`
If True, the pinned CPU buffers are retained after offloading
and reused for the next iteration. This is useful for CUDA graph capture.
manual_synchronization: bool, default = `False`
If True, the synchronization is done manually by the user.
Additional argument manual_controller is returned. See more in manual control section.
offload_stream: torch.cuda.Stream, default = `None`
If provided, the offload stream is used for offloading and reloading.
Otherwise, a new stream is allocated internally. It can be other than None
only if manual_synchronization is True.
Manual synchronization
----------------------
By default, layers are offloaded/reloaded asynchronously
with respect to the current forward/backward stream with predefined synchronization,
to ensure that activation memory usage is equal to
`(num_layers - num_offloaded_layers) * T`, where `T` is the memory footprint of a layer.
For more control over the offloading and reloading process, you can set `manual_synchronization=True`.
In this case, an additional argument, `manual_controller`, is returned.
The `manual_controller` provides the following methods:
- `start_offload_layer(layer_id: int)`
- `release_activation_forward_gpu_memory(layer_id: int)`
- `start_reload_layer(layer_id: int)`
If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded.
If `start_offload_layer()` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored.
To release this memory, you need to call `release_activation_forward_gpu_memory(layer_id)`.
This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
The `start_reload_layer()` method is used to start reloading a layer.
The current stream waits for each tensor's reload to finish before `tensor_pop()` is called for that tensor.
You can provide an `offload_stream` to be used for offload and reload operations.
This allows for more detailed synchronization, such as delaying the start of offloading.
Example:
.. code-block:: python
offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)
for i in range(num_layers):
with cpu_offload_context:
out[i] = layers[i].forward(inp[i])
out[i] = sync_function(out[i])
manual_controller.start_offload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
manual_controller.release_activation_forward_gpu_memory(i)
for i in range(num_layers - 1, -1, -1):
manual_controller.start_reload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
out[i].sum().backward()
V1 code path
----------
If you want to use the v1 code path for offloading,
please set the environment variable NVTE_CPU_OFFLOAD_V1 to 1.
"""
if NVTE_CPU_OFFLOAD_V1:
return v1_code_path.get_cpu_offload_context(
enabled=enabled,
num_layers=num_layers,
model_layers=model_layers,
offload_activations=offload_activations,
offload_weights=offload_weights,
double_buffering=double_buffering,
)
if not offload_weights and not offload_activations:
raise ValueError(
......@@ -703,8 +753,6 @@ def get_cpu_offload_context(
)
if offload_weights:
import warnings
warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.",
......@@ -713,26 +761,100 @@ def get_cpu_offload_context(
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
return nullcontext(), lambda x: x
return contextlib.nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
if TEDebugState.debug_enabled:
raise RuntimeError("CPU offload is not supported in debug mode.")
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker,
double_buffering=double_buffering,
)
if not manual_synchronization:
assert (
num_layers <= model_layers - 1
), "Cannot offload all layers without manual synchronization - last layer is not offloaded."
if num_layers == model_layers - 1:
warnings.warn(
"Offloading num_layers == model_layers - 1 is not recommended, it prevents"
" overlapping of computation and offload/reload."
)
assert (
offload_stream is None or manual_synchronization
), "offload_stream can be provided only if manual_synchronization is True"
if manual_synchronization:
offload_synchronizer = ManualOffloadSynchronizer(
model_layers, retain_pinned_cpu_buffers, offload_stream
)
else:
offload_synchronizer = DefaultOffloadSynchronizer(
model_layers,
num_layers,
retain_pinned_cpu_buffers,
offload_stream,
)
def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor, cpu_offload_handler)
class _CpuOffloadContext(contextlib.ContextDecorator):
def __init__(self):
self.current_layer = None
self.previous_offload_synchronizer = None
self.offload_synchronizer = offload_synchronizer
self.inside_context = False
def __enter__(self):
assert (
self.inside_context is False
), "Offloading context was entered without synchronization function being called."
self.inside_context = True
self._hooks_ctx = saved_tensors_hooks(
offload_synchronizer.push_tensor, offload_synchronizer.pop_tensor
)
self._hooks_ctx.__enter__()
global OFFLOAD_SYNCHRONIZER
self.previous_offload_synchronizer = OFFLOAD_SYNCHRONIZER
OFFLOAD_SYNCHRONIZER = offload_synchronizer
self.current_layer = offload_synchronizer.fwd_step()
return self
def __exit__(self, *args):
self._hooks_ctx.__exit__(*args)
global OFFLOAD_SYNCHRONIZER
OFFLOAD_SYNCHRONIZER = self.previous_offload_synchronizer
self.inside_context = False
def synchronization_function(self, tensor):
"""
This function is used to catch the backward pass of the model.
"""
assert tensor.requires_grad is True
assert self.current_layer is not None
cur_layer = self.current_layer
assert (
self.inside_context is False
), "Synchronization function must be called outside of the offloading context."
def hook(_):
# offload_synchronizer.finish_part_of_bwd needs
# to be called after every backward pass - there may be
# more than one in pipeline parallelism.
torch.autograd.variable.Variable._execution_engine.queue_callback(
offload_synchronizer.finish_part_of_bwd
)
offload_synchronizer.bwd_step(cur_layer)
tensor.grad_fn.register_prehook(hook)
return tensor
cpu_offload_context = _CpuOffloadContext()
if enabled:
if manual_synchronization:
return (
cpu_offload_context,
cpu_offload_context.synchronization_function,
offload_synchronizer,
)
return (
CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
group_prefetch_offload_commit_async,
cpu_offload_context,
cpu_offload_context.synchronization_function,
)
return nullcontext(), group_prefetch_offload_commit_async
return contextlib.nullcontext(), lambda x: x
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
from contextlib import nullcontext
from typing import Any, Dict, Optional
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False
CPUOffloadedLayer = False
def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor."""
if TEDebugState.debug_enabled:
raise RuntimeError("CPU offload is not supported in debug mode.")
for tensor in tensors:
if tensor is None:
continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
tensor.activation_offloading = True
else:
data_tensors = tensor.get_data_tensors()
for tensor in data_tensors:
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorStorage classes are saved in the ctx,
# and they contain the reference to their data tensors.
tensor.needs_force_clear = True
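The selection scheme above is attribute-based: a tensor is offloaded simply because an `activation_offloading` attribute was set on it, and the checker only tests for that attribute. A minimal dependency-free sketch of the pattern (plain objects stand in for tensors; `mark_for_offload`/`needs_offloading` are illustrative names, not the module's API):

```python
from types import SimpleNamespace


def mark_for_offload(*objs):
    """Mimic of the attribute-marking scheme: tag each object for offload."""
    for obj in objs:
        if obj is None:
            continue  # None entries are skipped, as in mark_activation_offload
        obj.activation_offloading = True


def needs_offloading(obj):
    """Mimic of tensor_need_offloading_checker_activations."""
    return hasattr(obj, "activation_offloading")


activation = SimpleNamespace()  # stand-in for a saved activation tensor
weight = SimpleNamespace()      # stand-in for a weight (never marked)

mark_for_offload(activation, None)

assert needs_offloading(activation)
assert not needs_offloading(weight)
```

The same test drives the real checker: only tensors explicitly marked by the layer implementation are copied to host memory; everything else saved for backward stays on the GPU.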
def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled
def is_current_layer_offloaded() -> bool:
"""Check if current layers is being offloaded."""
return CPUOffloadedLayer
class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
In this context, the ``on_save_for_backward`` method will be called every time
a tensor is saved for backward (this includes intermediary results saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation).
The ``on_get_saved_tensor`` method will be called when the backward function
of this op attempts to retrieve the saved tensor from context (this includes
:func:`torch.Tensor.backward` and :func:`torch.autograd.grad`). It takes
as input the return value of ``on_save_for_backward``, and is meant to return
a tensor identical to the one saved by ``on_save_for_backward`` in terms of
size, device and element values.
Example:
>>> import torch
>>> from typing import Any
>>>
>>> class DummyHook(CpuOffloadSavedTensorHook):
...
... def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
... logging.info("On save", tensor)
... return (tensor,)
...
... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
... logging.info("On get", saved_state)
... tensor, = saved_state
... return tensor
...
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with DummyHook():
... y = a * b
...
On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)
"""
def __init__(self) -> None:
self.inside_context = False
def __enter__(self):
global CPUOffloadEnabled
CPUOffloadEnabled = True
self.inside_context = True
torch._C._autograd._push_saved_tensors_default_hooks(
self.on_save_for_backward, self.on_get_saved_tensor
)
def __exit__(self, *args: Any):
global CPUOffloadEnabled
CPUOffloadEnabled = False
self.inside_context = False
torch._C._autograd._pop_saved_tensors_default_hooks()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
"""On save for backward."""
raise NotImplementedError(
"`on_save_for_backward: Callable[[torch.Tensor], Any]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
"""On get saved tensor."""
raise NotImplementedError(
"`on_get_saved_tensors: Callable[[Any], torch.Tensor]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
"""Context-manager that offloads/recovers tensors through an offload hander.
The hook just offloads/recovers the tensor object to the handler through `tensor_push`
and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
or prefetching timing is transparent to this hook.
"""
def __init__(
self,
offload_handler: OffloadHandler,
handler_extra_kwargs: Optional[Dict[str, Any]] = None,
debug: bool = False,
) -> None:
if handler_extra_kwargs is None:
handler_extra_kwargs = {}
self.debug: bool = debug
self.offload_handler: OffloadHandler = offload_handler
self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)
return tensor
class OffloadHandler:
"""A base class for CPU offload-handler."""
def __init__(self) -> None:
pass
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
"""Tensor push."""
raise NotImplementedError(
"`tensor_push` is not implemented in the OffloadHandler class. "
"Inherit this class and implement your custom tensor_push."
)
def tensor_pop(self, tensor_tag: Any, **kwargs):
"""Tensor pop."""
raise NotImplementedError(
"`tensor_pop` is not implemented in the OffloadHandler class. "
"Inherit this class and implement your custom tensor_pop."
)
class GroupCommitFunction(torch.autograd.Function):
"""this is a dummy op with output identical to input.
However, it is necessary for marking a timepoint for offload handler to
accomplish all synchronizations. Implementing it as a function is necessary
because we need to actions in both forward and backward.
"""
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
return tensor
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
group_prefetch_offload_commit = GroupCommitFunction.apply
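`GroupCommitFunction` exists because the handler must be notified twice per layer: once when the layer's forward finishes, and once when its backward begins. A dependency-free sketch of that idea, replacing autograd with a hand-rolled tape (the `Tape` and `Handler` names below are illustrative, not part of the module):

```python
class Handler:
    """Records commit notifications, like an offload handler would."""

    def __init__(self):
        self.events = []

    def on_group_commit_forward(self):
        self.events.append("fwd_commit")

    def on_group_commit_backward(self):
        self.events.append("bwd_commit")


class Tape:
    """Registers backward callbacks during forward, replays them in reverse."""

    def __init__(self):
        self._backward_hooks = []

    def commit(self, value, handler):
        # Forward-time action, like GroupCommitFunction.forward ...
        handler.on_group_commit_forward()
        # ... plus a backward-time action queued for later, like .backward
        self._backward_hooks.append(handler.on_group_commit_backward)
        return value  # identity, exactly like the dummy op

    def backward(self):
        for hook in reversed(self._backward_hooks):
            hook()


tape, handler = Tape(), Handler()
out = tape.commit("layer0_out", handler)
out = tape.commit("layer1_out", handler)
tape.backward()

assert handler.events == ["fwd_commit", "fwd_commit", "bwd_commit", "bwd_commit"]
```

In the real module, autograd plays the role of the tape: applying the function on a layer's output guarantees `on_group_commit_backward` fires exactly when the backward pass re-enters that layer.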
class SynchronizedGroupOffloadHandler(OffloadHandler):
"""Offload Handler that offloads/reloads in a synchronized way.
The device-to-host and host-to-device copying happen in the same stream
as the computation kernels, thus the copying will block computation.
"""
def __init__(
self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
) -> None:
super().__init__()
self.num_offload_group = num_offload_group
self.tensor_need_offloading_checker = tensor_need_offloading_checker
self.debug = debug
self.groupid_reset()
def groupid_reset(self):
"""Groupid reset."""
# Data structures to label saved tensors and book-keep their cpu copies.
# Currently, on push, we create a new CPU tensor and copy; on pop, we copy
# the tensor back to GPU and delete the CPU tensor.
# These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {}
def on_group_commit_forward(self):
"""On group commit forward."""
# finishing up with updating current group and tensor count
self.current_group += 1 # increment
self.tensor_count_current_group = 0 # reset
def on_group_commit_backward(self):
"""On group commit backward."""
self.current_group -= 1
assert self.current_group >= 0
@staticmethod
def offload(src_tensor, pin_memory=True):
"""Offload."""
cpu_backup = torch.empty(
src_tensor.size(),
dtype=src_tensor.dtype,
layout=src_tensor.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@staticmethod
def reload(state, non_blocking=None, copy_buffer=None):
"""Reload."""
dev, cpu_backup = state
if non_blocking is None:
non_blocking = cpu_backup.is_pinned()
if copy_buffer is None:
return cpu_backup.to(dev, non_blocking=non_blocking)
assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!"
copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
return copy_buffer
def tensor_push(self, tensor: torch.Tensor, **kwargs):
"""Tensor push."""
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
tensor
):
state = SynchronizedGroupOffloadHandler.offload(tensor)
self.tensor_tag_to_state[tensor_tag] = state
else:
# will be offloaded together after group commit
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
state = self.tensor_tag_to_state.pop(tensor_tag)
if isinstance(state, tuple):
tensor = SynchronizedGroupOffloadHandler.reload(state)
else:
tensor = state
return tensor
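The `(current_group, tensor_count_current_group)` tagging used by `tensor_push`/`tensor_pop` can be exercised without any tensors; the following sketch keeps the same bookkeeping with plain values in place of offloaded state (it deliberately skips the actual device-to-host copy, and `TagBook` is an illustrative name):

```python
class TagBook:
    """Minimal mimic of the push / commit / pop bookkeeping."""

    def __init__(self):
        self.current_group = 0   # incremented on every forward commit
        self.count = 0           # per-group tensor counter
        self.state = {}          # tag -> saved value (or offload state)

    def push(self, value):
        tag = (self.current_group, self.count)
        self.count += 1
        assert tag not in self.state
        self.state[tag] = value  # a real handler may offload here
        return tag               # returned to autograd instead of the tensor

    def commit_forward(self):
        self.current_group += 1
        self.count = 0

    def pop(self, tag):
        return self.state.pop(tag)


book = TagBook()
t0 = book.push("act_a")   # tag (0, 0)
t1 = book.push("act_b")   # tag (0, 1)
book.commit_forward()
t2 = book.push("act_c")   # tag (1, 0)

assert (t0, t1, t2) == ((0, 0), (0, 1), (1, 0))
assert book.pop(t1) == "act_b"
```

Because the tag (not the tensor) is what autograd stores, the handler is free to swap the underlying storage between host and device without the graph noticing.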
class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Compared to synchronize, this uses more memory because of the buffer but
achieves better performance due to the overlapping. D2h and h2d copying are
completely hidden behind computation if computation time of a layer is longer
than host-device communication time. Bulk offloading with delay and bulk reloading
with prefetch are implemented."""
def __init__(
self,
num_offload_group, # must be <= actual number of groups (number of commits)
num_model_group,
tensor_need_offloading_checker=(lambda t: True),
double_buffering=False,
debug=False,
) -> None:
super().__init__(
num_offload_group=num_offload_group,
tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug,
)
# Number of layers in the model
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
self.layer_window_map = {}
# Data structures for double-buffered reloading
self.double_buffering = double_buffering
self.reload_double_buffer = [[], []]
self.double_buffer_created = False
# Logic to make offloading load balance across computation
# for optimal CPU/GPU interconnect usage
constant = 0
for i in range(self.num_offload_group):
self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
if i < (self.num_layers % self.num_offload_group):
self.layer_window_map[i] += i + 1
constant = i + 1
else:
self.layer_window_map[i] += constant
# allocate streams and events for synchronization
self.d2h_stream = torch.cuda.Stream()
self.h2d_stream = torch.cuda.Stream()
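The window-map loop in `__init__` above spreads the offload synchronization points evenly across the model's layers, so D2H traffic is load-balanced against computation. The same arithmetic, reproduced standalone:

```python
def build_layer_window_map(num_layers, num_offload_group):
    """Same arithmetic as the __init__ loop: for offload group i, the map
    gives the layer index at which that group's offload is synchronized."""
    layer_window_map = {}
    constant = 0
    for i in range(num_offload_group):
        layer_window_map[i] = ((num_layers // num_offload_group) * (i + 1)) - 1
        if i < (num_layers % num_offload_group):
            # Distribute the remainder layers over the first few groups
            layer_window_map[i] += i + 1
            constant = i + 1
        else:
            layer_window_map[i] += constant
    return layer_window_map


# 10-layer model, 3 offloaded groups: sync points land after layers 3, 6, 9.
assert build_layer_window_map(10, 3) == {0: 3, 1: 6, 2: 9}
# Evenly divisible case: 8 layers, 4 groups -> every second layer.
assert build_layer_window_map(8, 4) == {0: 1, 1: 3, 2: 5, 3: 7}
```

Note that the last group always maps to the last layer, which is why `synchronize_on_group_commit_forward` never indexes past the map.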
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
)
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if is_quantized_tensor:
tensor_list, _ = tensor.prepare_for_saving()
self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = []
# Added support for de-duplicating FP8 param tensors
for _, value in self.fp8_tensor_object_map.items():
if tensor is value:
self.dereferencing_list.append(tensor_tag)
break
self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr(
tensor, "_transpose_invalid"
)
else:
tensor_list = [tensor]
for t in tensor_list:
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(t)
else:
self.tensor_tag_to_state[tensor_tag] = t
if (
self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(t)
):
if is_quantized_tensor:
self.tensor_tag_to_buf[tensor_tag].append(t)
# Need to clear the internal data reference for the quantized tensors
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
# Needed to differentiate a non-offloaded layer's attention:
# the QKV layout of a non-offloaded layer's attention needs
# to be modified while reloading
CPUOffloadedLayer = True
else:
tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
global CPUOffloadedLayer
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
# Handling the quantized tensor case specially here
if isinstance(tensor, list):
# If it's a duplicated tensor, we don't need to locally
# write back a tensor as it would already be written
if tensor_tag in self.dereferencing_list:
self.dereferencing_list.remove(tensor_tag)
else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag)
if self.double_buffering:
tensor._do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group.
assert not isinstance(tensor, tuple)
return tensor
def bulk_offload_group(self, group_to_offload):
"""Bulk offload group."""
with torch.cuda.stream(self.d2h_stream):
for tensor_tag, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_tag
if group_id == group_to_offload:
assert not isinstance(state, tuple)
is_quantized_tensor = isinstance(state, list)
if is_quantized_tensor:
tensor_list = state
self.tensor_tag_to_state[tensor_tag] = []
else:
tensor_list = [state]
for tensor_on_device in tensor_list:
# `tensor_offloaded` is a hacky way of dealing with columnwise-only
# quantized tensors for CPU offloading. The complication is due to
# the `rowwise_data` being `None`. The offloading checker incorrectly
# returns `False` and the entire `state` ([None, columnwise_tensor])
# is added to the tensor tag state dict. A better design would change
# how quantized tensors are kept track of in the offload handler.
# Currently at every stage it is ensured that a quantized tensor is a
# list whereas a non-quantized tensor is a standalone object, which is
# not good! TODO(@sanandaraj5597)
tensor_offloaded = False
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
tensor_offloaded = True
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
if tensor_offloaded:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have
# the first compute completion
if current_group == 0:
self.d2h_stream.wait_stream(torch.cuda.current_stream())
if not self.double_buffer_created:
# Creating the first copy of double buffer for tensors that are offloaded
for tensor_tag, buf in self.tensor_tag_to_buf.items():
if isinstance(buf, list):
for b in buf:
self.reload_double_buffer[0].append(
torch.empty_like(b) if self.double_buffering else None
)
else:
self.reload_double_buffer[0].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.bulk_offload_group(current_group)
# Window map data structure helps us synchronize based on number
# of layers offloaded
if self.layer_window_map[self.offloaded_group_count] == current_group:
# Stream synchronization both ways
self.d2h_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group
if self.offloaded_group_count < (self.num_offload_group - 1):
self.bulk_offload_group(self.offloaded_group_count + 1)
# Increment the offload group count to keep track
self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
for buf in self.reload_double_buffer[0]:
self.reload_double_buffer[1].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.double_buffer_created = True
def on_group_commit_forward(self):
"""This function will cause host device synchronization"""
# handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group)
super().on_group_commit_forward()
def bulk_reload_group(self, group_to_reload):
"""Bulk reload group."""
assert group_to_reload < self.num_offload_group
buffer_idx = 0
double_buffer_idx = group_to_reload % 2
main_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, reload_buffer
)
buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
True,
reload_buffer,
)
)
buffer_idx = buffer_idx + 1
else:
tensor_list.append(state_tuple)
# No need to write back the duplicated tensor again
# to the same location; this check ensures that
if tensor_label in self.dereferencing_list:
self.dereferencing_list.remove(tensor_label)
else:
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
tensor_list
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label)
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label
)
def on_group_commit_backward(self):
# first decrement the current group.
# after last commit in forward, the group will +1; in backward it -1.
# Finally it should be decremented to 0.
self.current_group -= 1
assert self.current_group >= 0
# Layer window data structure helps us to reload at right times
if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
# Stream synchronization both ways
self.h2d_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.h2d_stream)
# Time to reload the next group
self.bulk_reload_group(self.offloaded_group_count - 1)
# Decrease the offloading group counter
self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0
# Last group computation needs to wait till all the reloads complete
if self.current_group == 0:
torch.cuda.current_stream().wait_stream(self.h2d_stream)
self.offloaded_group_count = 0
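The backward reload schedule in `on_group_commit_backward` mirrors the forward offload schedule: group `g-1` starts reloading when the decrementing `current_group` hits `layer_window_map[g-1]`. A small simulation of just that bookkeeping (no streams or tensors; the window values are those the `__init__` loop produces for a 10-layer model with 3 offloaded groups):

```python
def reload_points(num_layers, window):
    """Return (group, current_group) pairs recording when each offloaded
    group is reloaded, following on_group_commit_backward's counters."""
    offloaded_group_count = len(window)  # all groups offloaded in forward
    current_group = num_layers           # value after the last forward commit
    reloads = []
    for _ in range(num_layers):
        current_group -= 1
        if window[offloaded_group_count - 1] == current_group:
            reloads.append((offloaded_group_count - 1, current_group))
            # Never drops below 1, exactly like the handler's counter
            offloaded_group_count -= 1 if offloaded_group_count > 1 else 0
    return reloads


window = {0: 3, 1: 6, 2: 9}  # 10 layers / 3 groups
# Group 2 reloads as backward reaches layer 9, group 1 at 6, group 0 at 3.
assert reload_points(10, window) == [(2, 9), (1, 6), (0, 3)]
```

Reloading in descending group order matches the order in which backward consumes activations, which is what lets the H2D copies overlap the preceding layers' gradient computation.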
def get_cpu_offload_context(
enabled: bool = False,
num_layers: int = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
double_buffering: bool = False,
):
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
Usage:
.. code-block:: python
cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)
with cpu_offload_context:
te_layer.forward(inp_tensor)
cpu_offload_synchronizer()
Parameters
----------
enabled: bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1
Determines the number of transformer layers
you want to offload activations/weights for.
model_layers: int, default = 1
Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `True`
When set to `True`, offloads the weights for the TE layer.
double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading.
"""
if not offload_weights and not offload_activations:
raise ValueError(
"CPU Offloading is enabled while it is not "
"mentioned what to offload (weights/activations)"
)
if offload_weights:
import warnings
warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.",
DeprecationWarning,
)
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
return nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker,
double_buffering=double_buffering,
)
def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor, cpu_offload_handler)
if enabled:
return (
CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
group_prefetch_offload_commit_async,
)
return nullcontext(), group_prefetch_offload_commit_async
......@@ -41,7 +41,7 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..quantized_tensor import (
......@@ -135,6 +135,9 @@ class _GroupedLinear(torch.autograd.Function):
else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
if cpu_offloading:
start_offload(*inputmats)
# Initialize weights
weights_fp8: list
if fp8:
......@@ -196,6 +199,9 @@ class _GroupedLinear(torch.autograd.Function):
for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i])
if cpu_offloading:
mark_not_offload(*weights_fp8, *weights)
if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1]
......
......@@ -66,10 +66,15 @@ from ..quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
general_gemm,
......@@ -158,6 +163,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
if is_cpu_offload_enabled():
start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
weight_requires_grad = weight.requires_grad
......@@ -434,8 +442,14 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
mark_not_offload(
weightmat,
weight,
bias,
ln_weight,
ln_bias,
)
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
......@@ -542,6 +556,7 @@ class _LayerNormLinear(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
......
......@@ -69,7 +69,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
......@@ -235,6 +240,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if is_cpu_offload_enabled():
start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
......@@ -577,6 +584,18 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(act_out)
act_out = None
if cpu_offloading:
mark_not_offload(
ln_weight,
ln_bias,
fc1_weight_final,
fc1_weight,
fc1_bias,
fc2_weight_final,
fc2_weight,
fc2_bias,
)
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
......
......@@ -68,7 +68,12 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_custom
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["Linear"]
......@@ -229,6 +234,9 @@ class _Linear(torch.autograd.Function):
else:
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
inputmat_total = inputmat
if is_cpu_offload_enabled():
start_offload(inputmat)
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
......@@ -417,6 +425,7 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx.weight_object = weight
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
......
......@@ -372,9 +372,9 @@ class FusedAdam(torch.optim.Optimizer):
"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
data = torch.zeros_like(param, dtype=torch.int16)
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
else:
data = torch.empty_like(param, dtype=dtype)
data = torch.empty(param.shape, dtype=dtype, device=param.device)
if zero_buffer:
data.zero_()
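A likely motivation for replacing `torch.zeros_like(param, ...)` with `torch.zeros(param.shape, ...)` here is that `*_like` factories dispatch through the input's tensor subclass, so optimizer state allocated from a quantized parameter could itself come back quantized; explicit allocation from shape and dtype always yields a plain buffer. A torch-free analogy (all classes hypothetical):

```python
# Illustrative sketch (not torch): why "*_like" factories can leak a subclass
# into optimizer state, while explicit allocation from shape cannot.

class Plain:
    def __init__(self, shape):
        self.shape = shape

class Quantized(Plain):  # stand-in for a QuantizedTensor parameter
    pass

def zeros_like(t):  # "*_like": preserves the input's type
    return type(t)(t.shape)

def zeros(shape):   # explicit: always the base type
    return Plain(shape)

param = Quantized((4, 4))
assert type(zeros_like(param)) is Quantized  # state would be quantized too
assert type(zeros(param.shape)) is Plain     # state stays a plain buffer
```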
......
......@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
import math
import torch
from torch.utils._pytree import tree_map
......@@ -20,6 +21,11 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
_stride_from_shape,
)
_quantized_tensor_cpu_supported_ops = (
torch.ops.aten.empty_like.default,
torch.ops.aten.copy_.default,
)
class QuantizedTensorStorage:
r"""Base class for all *TensorStorage classes.
......@@ -35,7 +41,7 @@ class QuantizedTensorStorage:
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (liek __torch_dispatch__)."""
to behave like regular torch.Tensor (like __torch_dispatch__)."""
_quantizer: Optional[Quantizer]
......@@ -63,6 +69,12 @@ class QuantizedTensorStorage:
f"{self.__class__.__name__} class does not implement update_usage function"
)
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement get_usages function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
......@@ -128,6 +140,7 @@ def prepare_for_saving(
t, t_obj = tensor.prepare_for_saving()
tensor_list.extend(t)
tensor_objects_list.append(t_obj)
return tensor_list, tensor_objects_list
......@@ -314,6 +327,13 @@ class Quantizer(abc.ABC):
"""Returns whether or not given tensor can be quantized"""
return True
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the quantizer"""
return {
"rowwise": self.rowwise_usage,
"columnwise": self.columnwise_usage,
}
class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
......@@ -325,7 +345,14 @@ class QuantizedTensor(torch.Tensor):
"""
def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False):
def __new__(
cls,
shape: Iterable[int],
dtype: torch.dtype,
*,
requires_grad: bool = False,
device: Optional[torch.device] = None,
):
# We are assuming only contiguous tensors
stride = _stride_from_shape(shape)
instance = torch.Tensor._make_wrapper_subclass(
......@@ -336,7 +363,7 @@ class QuantizedTensor(torch.Tensor):
dtype=dtype,
layout=torch.strided,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
device=torch.cuda.current_device() if device is None else device,
)
return instance
......@@ -366,6 +393,9 @@ class QuantizedTensor(torch.Tensor):
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement clear function"
)
def __repr__(self, *, tensor_contents=None) -> str:
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
......@@ -407,6 +437,26 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.copy_.default:
dst = args[0]
src = args[1]
if (
isinstance(dst, QuantizedTensor)
and isinstance(src, QuantizedTensor)
and type(dst._quantizer) is type(src._quantizer)
and src.get_usages() == dst.get_usages()
):
dst_tensors, dst_tensor_obj = dst.prepare_for_saving()
src_tensors, src_tensor_obj = src.prepare_for_saving()
for dst_tensor, src_tensor in zip(dst_tensors, src_tensors):
if dst_tensor is not None:
dst_tensor.copy_(src_tensor, *args[2:], **kwargs)
dst_tensor_obj.restore_from_saved(dst_tensors)
src_tensor_obj.restore_from_saved(src_tensors)
return None
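When the quantizer types and usages match, the copy above avoids a dequantize/requantize round trip by copying the raw backing tensors pairwise through the `prepare_for_saving`/`restore_from_saved` protocol. A toy sketch of that flow, with lists standing in for data tensors (`Storage` here is not transformer_engine's `QuantizedTensorStorage`):

```python
# Hedged sketch of the copy_ fast path: matching quantized tensors are copied
# via their flat tensor lists, then the storage objects are restored.

class Storage:
    def __init__(self, rowwise, columnwise):
        self.rowwise = rowwise        # list stands in for a data tensor
        self.columnwise = columnwise  # may be None when usage is off

    def get_usages(self):
        return {"rowwise": self.rowwise is not None,
                "columnwise": self.columnwise is not None}

    def prepare_for_saving(self):
        tensors = [self.rowwise, self.columnwise]
        self.rowwise = self.columnwise = None
        return tensors, self

    def restore_from_saved(self, tensors):
        self.rowwise, self.columnwise = tensors


def copy_(dst, src):
    assert dst.get_usages() == src.get_usages()  # fast-path precondition
    dst_tensors, dst_obj = dst.prepare_for_saving()
    src_tensors, src_obj = src.prepare_for_saving()
    for d, s in zip(dst_tensors, src_tensors):
        if d is not None:
            d[:] = s  # stands in for torch.Tensor.copy_
    dst_obj.restore_from_saved(dst_tensors)
    src_obj.restore_from_saved(src_tensors)


src = Storage([1, 2, 3], [4, 5, 6])
dst = Storage([0, 0, 0], [0, 0, 0])
copy_(dst, src)
assert dst.rowwise == [1, 2, 3] and dst.columnwise == [4, 5, 6]
```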
if isinstance(dst, QuantizedTensor):
dst.quantize_(src)
else:
......@@ -419,6 +469,36 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.view.default:
raise NotImplementedError(f"{cls.__name__} class does not support tensor views")
# Empty like op
if func == torch.ops.aten.empty_like.default:
tensor = args[0]
device = kwargs.get("device", tensor.device)
requires_grad = kwargs.get("requires_grad", tensor.requires_grad)
pin_memory = kwargs.get("pin_memory", False)
usage = tensor.get_usages()
quantizer_usage = tensor._quantizer.get_usages()
tensor._quantizer.set_usage(**usage)
out = tensor._quantizer.make_empty(
shape=tensor.shape,
dtype=tensor.dtype,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory,
)
tensor._quantizer.set_usage(**quantizer_usage)
return out
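The `empty_like` handler above temporarily overrides the quantizer's usage flags with the source tensor's usages so the fresh tensor allocates exactly the same buffers, then restores the quantizer's own flags. A minimal sketch of that save/override/restore pattern (toy `Quantizer`, not the real class):

```python
# Sketch: override quantizer usages for one allocation, then restore them.

class Quantizer:
    def __init__(self):
        self.rowwise, self.columnwise = True, False

    def get_usages(self):
        return {"rowwise": self.rowwise, "columnwise": self.columnwise}

    def set_usage(self, rowwise, columnwise):
        self.rowwise, self.columnwise = rowwise, columnwise

    def make_empty(self):
        # Allocate only the buffers the current usages require.
        return {"rowwise": [] if self.rowwise else None,
                "columnwise": [] if self.columnwise else None}


q = Quantizer()
tensor_usage = {"rowwise": True, "columnwise": True}  # what the source holds
saved = q.get_usages()
q.set_usage(**tensor_usage)  # override for the allocation only
out = q.make_empty()
q.set_usage(**saved)         # restore the quantizer's own flags

assert out["columnwise"] is not None  # matches the source tensor
assert q.get_usages() == {"rowwise": True, "columnwise": False}
```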
if func == torch.ops.aten.numel.default:
tensor = args[0]
return math.prod(tensor.size())
if func == torch.ops.aten.is_pinned.default:
tensor = args[0]
for t in tensor.get_data_tensors():
if t is not None:
return func(t)
return False  # no data tensors allocated, so nothing is pinned
def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
......@@ -463,6 +543,16 @@ class QuantizedTensor(torch.Tensor):
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
def check_if_cpu(arg):
if isinstance(arg, QuantizedTensor) and arg.device.type == "cpu":
assert (
func in _quantized_tensor_cpu_supported_ops
), f"QuantizedTensor on CPU does not support this operation: {func}"
return arg
args = tree_map(check_if_cpu, args)
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
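The guard above walks every argument (via `tree_map` in the real code) and only lets CPU-resident quantized tensors through a small whitelist of ops. A self-contained sketch with illustrative names:

```python
# Sketch of the CPU-op whitelist guard: CPU-resident quantized tensors are
# only allowed through a small set of ops; everything else asserts.

CPU_SUPPORTED_OPS = {"empty_like", "copy_"}  # stands in for aten overloads

class FakeQuantized:
    def __init__(self, device):
        self.device = device

def check_if_cpu(arg, func):
    if isinstance(arg, FakeQuantized) and arg.device == "cpu":
        assert func in CPU_SUPPORTED_OPS, (
            f"QuantizedTensor on CPU does not support this operation: {func}")
    return arg

def guard(func, args):
    # Flat list here; the real code uses torch.utils._pytree.tree_map.
    return [check_if_cpu(a, func) for a in args]

guard("copy_", [FakeQuantized("cpu")])    # allowed on CPU
guard("matmul", [FakeQuantized("cuda")])  # GPU tensors are unrestricted
try:
    guard("matmul", [FakeQuantized("cpu")])
    blocked = False
except AssertionError:
    blocked = True
assert blocked
```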
......
......@@ -214,6 +214,7 @@ class Float8BlockQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
......@@ -229,12 +230,13 @@ class Float8BlockQuantizer(Quantizer):
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
......@@ -242,13 +244,17 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
self.get_columnwise_shape(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......
......@@ -101,6 +101,7 @@ class Float8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
......@@ -108,16 +109,19 @@ class Float8Quantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......@@ -125,7 +129,7 @@ class Float8Quantizer(Quantizer):
shape=shape,
dtype=dtype,
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
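Note that the transpose buffer's shape is now computed from `shape` directly rather than from `data`, since `data` may be `None` when rowwise usage is off. A quick check of that shape arithmetic:

```python
# The columnwise (transpose) buffer rotates the last dim to the front.
shape = (8, 16, 32)
transpose_shape = [shape[-1]] + list(shape[:-1])
assert transpose_shape == [32, 8, 16]
```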
......@@ -287,6 +291,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
......@@ -294,23 +299,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
return Float8Tensor(
shape=shape,
dtype=dtype,
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
......@@ -715,14 +723,22 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
return cls.clone(args[0])
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
# Just copy FP8 attrs if copying between Float8Tensors
if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor):
dst._data.copy_(src._data.detach())
dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size()))
if src._transpose is not None or dst._transpose is not None:
dst._create_transpose()
if dst._data is not None:
dst._data.copy_(src._data.detach(), *args[2:], **kwargs)
if dst._scale_inv is not None:
dst._scale_inv.copy_(
src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs
)
if dst._transpose is not None and not dst._transpose_invalid:
if not src._transpose_invalid:
dst._transpose.copy_(src._transpose, *args[2:], **kwargs)
else:
dst._create_transpose()
return dst
elif func in _ops_to_preserve_subclass_in_fsdp2:
# Ops in _ops_to_preserve_subclass_in_fsdp2 are recommended to return the same class instance to work correctly with torch FSDP2
......
......@@ -90,6 +90,7 @@ class MXFP8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> MXFP8Tensor:
# Canonicalize tensor attributes
......@@ -105,24 +106,29 @@ class MXFP8Quantizer(Quantizer):
)
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
)
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
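The scale-inverse shapes allocated above pad the scale grid for alignment. A worked example, assuming `MXFP8_BLOCK_SCALING_SIZE = 32` and the 128/4 rounding used in the diff (`round_up_to_nearest_multiple` is reimplemented here for illustration):

```python
# Worked example of the MXFP8 rowwise/columnwise scale-inverse shapes.
import math

MXFP8_BLOCK_SCALING_SIZE = 32

def round_up_to_nearest_multiple(x, m):
    return ((x + m - 1) // m) * m

shape = (1024, 2048)
rowwise_scale_shape = (
    round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
    round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
)
columnwise_scale_shape = (
    round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
    round_up_to_nearest_multiple(shape[-1], 128),
)
assert rowwise_scale_shape == (1024, 64)
assert columnwise_scale_shape == (32, 2048)
```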
......@@ -348,11 +354,17 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
)
if rowwise_matches and columnwise_matches:
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach())
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
dst._rowwise_scale_inv.copy_(
src._rowwise_scale_inv.detach(), *args[2:], **kwargs
)
if dst._columnwise_data is not None:
dst._columnwise_data.copy_(src._columnwise_data.detach())
dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
dst._columnwise_data.copy_(
src._columnwise_data.detach(), *args[2:], **kwargs
)
dst._columnwise_scale_inv.copy_(
src._columnwise_scale_inv.detach(), *args[2:], **kwargs
)
return dst
# FSDP2 related functions.
......
......@@ -6,7 +6,7 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import functools
import torch
......@@ -265,6 +265,7 @@ class NVFP4Quantizer(Quantizer):
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
) -> NVFP4Tensor:
......@@ -288,11 +289,18 @@ class NVFP4Quantizer(Quantizer):
scale_inv = None
amax_rowwise = None
if self.rowwise_usage:
data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
data = torch.empty(
self.convert_shape_for_fp4(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
# Allocate per tensor scale inverse. FP32 format.
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
columnwise_data = None
......@@ -306,12 +314,15 @@ class NVFP4Quantizer(Quantizer):
self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape, dtype=torch.uint8, device=device
columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
amax_columnwise = torch.zeros(
1, dtype=torch.float32, device=device, pin_memory=pin_memory
)
amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
# Construct FP8 tensor
return NVFP4Tensor(
......@@ -498,6 +509,12 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
return self
raise ValueError("NVFP4Tensor does not support different memory formats!")
def get_usages(self) -> Dict[str, bool]:
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......@@ -520,16 +537,20 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
)
if tensor._rowwise_data is not None:
rowwise_data = data_init_func(tensor._rowwise_data)
rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs)
rowwise_scale_inv = scale_inv_init_func(
tensor._rowwise_scale_inv, *args[1:], **kwargs
)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs)
else:
rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
if tensor._columnwise_data is not None:
columnwise_data = data_init_func(tensor._columnwise_data)
columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs)
columnwise_scale_inv = scale_inv_init_func(
tensor._columnwise_scale_inv, *args[1:], **kwargs
)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs)
else:
columnwise_data, columnwise_scale_inv, amax_columnwise = (
None,
......
......@@ -420,3 +420,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return
return
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
......@@ -225,3 +225,12 @@ class Float8TensorStorage(QuantizedTensorStorage):
if not needs_data_transpose:
self._transpose = None
self._transpose_invalid = True
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
usages = {"rowwise": self._data is not None}
if is_non_tn_fp8_gemm_supported():
usages["columnwise"] = self._data is not None
else:
usages["columnwise"] = self._transpose is not None and not self._transpose_invalid
return usages
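The columnwise-usage rule above can be captured in a small truth table: when non-TN FP8 GEMM is supported, the rowwise data serves both usages; otherwise a valid transpose buffer is required. A standalone sketch (`non_tn_gemm_supported` stands in for `is_non_tn_fp8_gemm_supported()`):

```python
# Sketch of Float8TensorStorage.get_usages as a pure function.

def get_usages(data, transpose, transpose_invalid, non_tn_gemm_supported):
    usages = {"rowwise": data is not None}
    if non_tn_gemm_supported:
        usages["columnwise"] = data is not None
    else:
        usages["columnwise"] = transpose is not None and not transpose_invalid
    return usages

# Non-TN GEMM support: rowwise data alone enables columnwise use.
assert get_usages([1], None, True, True) == {"rowwise": True, "columnwise": True}
# Without it, a missing or invalid transpose disables columnwise use.
assert get_usages([1], None, True, False) == {"rowwise": True, "columnwise": False}
assert get_usages([1], [2], False, False) == {"rowwise": True, "columnwise": True}
```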