Unverified commit c5257605, authored by Paweł Gadziński, committed by GitHub

[PyTorch] Activation offloading refactor (#1762)



* init
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* offloading
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* all types
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* typo
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* init
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* api change
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* refactor
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* tests
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* example
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* cpu offload + debug warning
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* test
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change empty_like implementation to use make_like
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* main_grad fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* manual synchronization
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* old path
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* remove example
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* api changes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* reverted grouped linear
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* make old code path work for modules
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* attention old code path
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* updated code path
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* nvfp4 support
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update tests/pytorch/test_cpu_offloading.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* small fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* docs change
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@ptyche0312.ptyche.clusters.nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a0754757
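
As context for the diffs that follow, here is a minimal usage sketch of the refactored offloading API, written from the new tests in this change (the layer type, tensor sizes, and layer counts below are illustrative only, not part of the commit):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context

    num_layers = 4
    # num_layers = how many layers to offload, model_layers = total layers wrapped,
    # as used by the new tests below.
    offload_ctx, sync_function = get_cpu_offload_context(
        enabled=True,
        num_layers=num_layers - 1,
        model_layers=num_layers,
    )

    layers = [te.Linear(256, 256, params_dtype=torch.bfloat16) for _ in range(num_layers)]
    x = torch.randn(64, 256, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = x
    for layer in layers:
        # run each layer under the offload context, then hit the synchronization point
        with offload_ctx:
            out = layer(out)
        out = sync_function(out)
    out.sum().backward()
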
@@ -42,7 +42,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
-NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
...
@@ -2,27 +2,41 @@
#
# See LICENSE for license information.
import random
import contextlib
import pytest
import os
import torch
from typing import Optional, List
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context,
OffloadableLayerState,
DefaultOffloadSynchronizer,
start_offload,
mark_not_offload,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from utils import ModelConfig
import transformer_engine_torch as tex

# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available()

quantization_recipes: List[Optional[recipe.Recipe]] = [None]
if fp8_available:
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
if fp8_block_scaling_available:
quantization_recipes.append(recipe.Float8BlockScaling())
if mxfp8_available:
quantization_recipes.append(recipe.MXFP8BlockScaling())
if nvfp4_available:
quantization_recipes.append(recipe.NVFP4BlockScaling())

model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
@@ -32,181 +46,709 @@ NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Disable garbage collection to tests if there are reference cycles.
# We do not want them, because they can result in CUDA out of memory errors.
import gc

gc.disable()


class Utils:
tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
_B = 64
_S = 256
_H = 4
_D = 256

@staticmethod
def long_job(stream: Optional[torch.cuda.Stream] = None):
NUM_ITERS = 6000
if stream is None:
stream = torch.cuda.current_stream()

with torch.cuda.stream(stream):
for i in range(NUM_ITERS):
Utils.tensor1.normal_()

@staticmethod
def measure_time(func):
import time
torch.cuda.synchronize()
start = time.time()
func()
torch.cuda.synchronize()
end = time.time()
return (end - start) * 1000
@staticmethod
def get_cuda_memory_mb():
return torch.cuda.memory_allocated() / (1024**2)
@staticmethod
def get_max_cuda_memory_mb():
return torch.cuda.max_memory_allocated() / (1024**2)
@staticmethod
def get_cpu_memory_mb() -> float:
import psutil, os
return psutil.Process(os.getpid()).memory_info().rss / (1024**2)
@staticmethod
def get_layer_names():
return [
"linear",
"layernorm_linear",
"layernorm_mlp",
"grouped_linear",
"multihead_attention",
"transformer_layer",
"linear_op",
"layernorm_mlp_ops",
]
@staticmethod
def create_layer(layer_type: str):
if layer_type == "linear":
return te.Linear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "layernorm_linear":
return te.LayerNormLinear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "layernorm_mlp":
return te.LayerNormMLP(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "multihead_attention":
return te.MultiheadAttention(
Utils._D, Utils._H, attention_dropout=0.0, params_dtype=torch.bfloat16
)
elif layer_type == "grouped_linear":
return te.GroupedLinear(Utils._H, Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "transformer_layer":
return te.TransformerLayer(
Utils._D,
Utils._D,
Utils._H,
attention_dropout=0.0,
hidden_dropout=0.0,
params_dtype=torch.bfloat16,
)
elif layer_type == "linear_op":
return te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16)
elif layer_type == "layernorm_mlp_ops":
return te.ops.Sequential(
te.ops.LayerNorm(Utils._D, dtype=torch.bfloat16),
te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
)
else:
raise ValueError(f"Unknown layer type: {layer_type}")
@staticmethod
def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) -> torch.Tensor:
shape = (Utils._B, Utils._S, Utils._D)
tensor = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
if recipe is None:
tensor = tensor.requires_grad_() if requires_grad else tensor
return tensor
elif recipe.delayed():
quantizer = te.tensor.float8_tensor.Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
scale=torch.tensor([1.0], device="cuda"),
amax=torch.tensor([1.0], device="cuda"),
)
return quantizer(tensor)
elif recipe.float8_current_scaling():
quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"
)
return quantizer(tensor)
elif recipe.float8_block_scaling():
quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True
)
return quantizer(tensor)
elif recipe.mxfp8():
quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
return quantizer(tensor)
elif recipe.nvfp4():
quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer()
return quantizer(tensor)
@staticmethod
def create_recipe_ctx(recipe: Optional[recipe.Recipe]):
if recipe is None:
return lambda: contextlib.nullcontext()
else:
return lambda: te.fp8_autocast(fp8_recipe=recipe)
@staticmethod
def get_tensor_size_mb(tensor):
if tensor is None:
return 0
if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage):
return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors())
else:
return tensor.numel() * tensor.element_size() / (1024**2)
@staticmethod
def memory_leak_check():
# Should be called before each test.
# Only cublas workspaces and some global tensors are allowed to be allocated.
# All other allocations should be released.
# This is a simple check to catch memory leaks.
if Utils.get_cuda_memory_mb() > 1000:
memory_num = Utils.get_cuda_memory_mb()
import gc
gc.collect() # We want next test to be run with clean state.
gc.disable()
raise RuntimeError(f"Memory leak: {memory_num} MB")
class TestsOffloadableLayerState:
@pytest.mark.parametrize("random_num_tensors", [True, False])
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_general(self, random_num_tensors, recipe):
"""
Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
for each layer offload random number of random tensors.
Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
"""
Utils.memory_leak_check()
NUM_ITERATIONS = 10
stream = torch.cuda.Stream()
offload_layer_state = OffloadableLayerState(
offload_stream=stream,
)
for _ in range(NUM_ITERATIONS):
original_tensors = []
tensors_ids = []
NUM_TENSORS = random.choice([1, 20]) if random_num_tensors else 1
for _ in range(NUM_TENSORS):
tensor = Utils.create_tensor(recipe)
original_tensors.append(tensor)
tensor_id = offload_layer_state.push_tensor(tensor)
assert tensor.device.type == "cuda"
tensors_ids.append(tensor_id)
offload_layer_state.start_offload()
offload_layer_state.release_activation_forward_gpu_memory()
offload_layer_state.start_reload()
for j in range(len(tensors_ids)):
tensor_gpu = offload_layer_state.pop_tensor(tensors_ids[j])
assert tensor_gpu.device.type == "cuda"
assert tensor_gpu.shape == original_tensors[j].shape
assert tensor_gpu.dtype == original_tensors[j].dtype
torch.testing.assert_close(tensor_gpu, original_tensors[j])
offload_layer_state.release_all_memory()
torch.cuda.synchronize()
def test_offload_base_tensor(self):
Utils.memory_leak_check()
stream = torch.cuda.Stream()
offload_layer_state = OffloadableLayerState(
offload_stream=stream,
)
init_cuda_memory = Utils.get_cuda_memory_mb()
x = Utils.create_tensor(None)
x_size = Utils.get_tensor_size_mb(x)
x_1 = x[::2]
x_2 = x[1::2]
start_offload(x_1, offload_base_tensor=True)
start_offload(x_2, offload_base_tensor=True)
x1_id = offload_layer_state.push_tensor(x_1)
x2_id = offload_layer_state.push_tensor(x_2)
del x_1, x_2
offload_layer_state.start_offload()
offload_layer_state.release_activation_forward_gpu_memory()
assert offload_layer_state.get_offloaded_total_size_mb() == pytest.approx(x_size, 0.1)
offload_layer_state.start_reload()
x_1 = offload_layer_state.pop_tensor(x1_id)
x_2 = offload_layer_state.pop_tensor(x2_id)
assert x_1.device.type == "cuda"
assert x_2.device.type == "cuda"
assert torch.allclose(x_1, x[::2])
assert torch.allclose(x_2, x[1::2])
del x
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1)
class TestsDefaultOffloadSynchronizer:
@pytest.mark.parametrize("random_num_tensors", [True, False])
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_general(self, random_num_tensors, recipe):
"""
Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
for each layer offload random number of random tensors.
Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
"""
Utils.memory_leak_check()
NUM_LAYERS = 10
NUM_ITERATIONS = 10
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=NUM_LAYERS,
num_offloaded_layers=NUM_LAYERS - 1,
)
for _ in range(NUM_ITERATIONS):
original_tensors = []
tensors_ids = []
layer_ids = []
for i in range(NUM_LAYERS):
NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1
layer_tensors = []
layer_tensors_ids = []
layer_id = offload_synchronizer.fwd_step()
for _ in range(NUM_LAYER_TENSORS):
tensor = Utils.create_tensor(recipe)
layer_tensors.append(tensor)
tensor_id = offload_synchronizer.push_tensor(tensor)
assert tensor.device.type == "cuda"
layer_tensors_ids.append(tensor_id)
layer_ids.append(layer_id)
tensors_ids.append(layer_tensors_ids)
original_tensors.append(layer_tensors)
for i in range(NUM_LAYERS - 1, -1, -1):
offload_synchronizer.bwd_step(layer_ids[i])
for j in range(len(tensors_ids[i])):
tensor_gpu = offload_synchronizer.pop_tensor(tensors_ids[i][j])
assert tensor_gpu.device.type == "cuda"
assert tensor_gpu.shape == original_tensors[i][j].shape
assert tensor_gpu.dtype == original_tensors[i][j].dtype
torch.testing.assert_close(tensor_gpu, original_tensors[i][j])
offload_synchronizer.finish_part_of_bwd()
torch.cuda.synchronize()
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_memory(self, recipe):
torch.cuda.synchronize()
Utils.memory_leak_check()
NUM_LAYERS = 10
torch.cuda.reset_peak_memory_stats()
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=NUM_LAYERS,
num_offloaded_layers=NUM_LAYERS - 1,
)
init_cuda_memory = Utils.get_cuda_memory_mb()
tensor_ids = []
torch.cuda.synchronize()
for _ in range(NUM_LAYERS):
offload_synchronizer.fwd_step()
tensor = Utils.create_tensor(recipe)
tensor_size = Utils.get_tensor_size_mb(tensor)
tensor_id = offload_synchronizer.push_tensor(tensor)
assert tensor.device.type == "cuda"
tensor_ids.append(tensor_id)
del tensor, tensor_id
torch.cuda.synchronize()
if recipe is None:
assert Utils.get_max_cuda_memory_mb() == pytest.approx(
init_cuda_memory + tensor_size, 0.1
)
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1)
for i in range(NUM_LAYERS - 1, -1, -1):
offload_synchronizer.bwd_step(i)
tensor_gpu = offload_synchronizer.pop_tensor(tensor_ids[i])
assert tensor_gpu.device.type == "cuda"
del tensor_gpu, tensor_ids[i]
offload_synchronizer.finish_part_of_bwd()
del tensor_ids
torch.cuda.synchronize()
if recipe is None:
assert Utils.get_max_cuda_memory_mb() == pytest.approx(
init_cuda_memory + tensor_size, 0.1
)
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)

@pytest.mark.parametrize("recipe", quantization_recipes)
def test_multiple_tensor_offload(self, recipe):
Utils.memory_leak_check()
init_cpu_memory = Utils.get_cpu_memory_mb()
init_cuda_memory = Utils.get_cuda_memory_mb()
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=2,
num_offloaded_layers=1,
)
x1 = Utils.create_tensor(recipe)
x_size = Utils.get_tensor_size_mb(x1)
offload_synchronizer.fwd_step()
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.fwd_step()
# Only one copy of tensor on cpu is allocated.
assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
del x1
offload_synchronizer.bwd_step(1)
offload_synchronizer.bwd_step(0)
offload_synchronizer.finish_part_of_bwd()
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
class TestTELayers:
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_sanity(self, layer_type, recipe):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
): ):
tensor = module(tensor) pytest.skip("Fusible operations do not support FP8 block scaling recipe")
tensor.sum().backward()
recipe_ctx = Utils.create_recipe_ctx(recipe)
init_cuda_memory = Utils.get_cuda_memory_mb()
OFFLOAD_LAYERS = 6
NUM_LAYERS = 10
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=OFFLOAD_LAYERS,
model_layers=NUM_LAYERS,
)
layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)]
inp = Utils.create_tensor(None)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
out = inp
for i in range(NUM_LAYERS):
with offload_ctx, recipe_ctx():
# Ops-based layers don't support is_first_microbatch parameter
if layer_type in ["linear_op", "layernorm_mlp_ops"]:
out = layers[i](out, **m_splits)
else:
out = layers[i](out, is_first_microbatch=False, **m_splits)
out = sync_function(out)
out.sum().backward()
torch.cuda.synchronize()
del out, inp, layers
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_memory(self, layer_type, recipe):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
def _estimate_cached_weight_size( offload_ctx, sync_function = get_cpu_offload_context(
model_name: str, enabled=True,
modules: Iterable[torch.nn.Module], num_layers=1,
quantization_recipe: Optional[recipe.Recipe], model_layers=2,
) -> float: offload_activations=True,
"""Calculate the memory (in MiB) needed for weight caching.""" offload_weights=False,
)
recipe_ctx = Utils.create_recipe_ctx(recipe)
layer = Utils.create_layer(layer_type)
inp = Utils.create_tensor(None)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
# The weight params are cached directly for unquantized compute # Ops-based layers don't support is_first_microbatch parameter
if quantization_recipe is None: is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
return 0
# Count number of weight param elements with recipe_ctx():
param_elements = 0 if is_ops_layer:
for module in modules: out = layer(inp, **m_splits)
for param in module.parameters(): else:
if param.dim() == 2: out = layer(inp, is_first_microbatch=True, **m_splits)
param_elements += param.numel() out.sum().backward()
# FP8 tensor-scaling caches one byte per element del inp
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): init_cuda_memory = Utils.get_cuda_memory_mb()
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op", # run layer without offload
"layernorm_mlp_ops", inp = Utils.create_tensor(None)
with recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=False, **m_splits)
with recipe_ctx():
out = out + 1
del inp
cuda_memory_no_offload = Utils.get_cuda_memory_mb()
out.sum().backward()
# run layer with offload
inp = Utils.create_tensor(None)
with offload_ctx, recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=False, **m_splits)
out = sync_function(out)
with offload_ctx, recipe_ctx():
out = out + 1
out = sync_function(out)
del inp
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb()
# This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer.
# It helps catch cases where an offloaded tensor still has a live pointer, which would
# cause an unnecessary copy to the CPU and prevent GPU memory from being released.
assert Utils.get_cuda_memory_mb() + offloaded_memory_cpu == pytest.approx(
cuda_memory_no_offload, 0.1
)
out.sum().backward()
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_manual_synchronization(self, recipe, layer_type):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
): ):
# Modules do not deallocate FP8 transpose for weights pytest.skip("Fusible operations do not support FP8 block scaling recipe")
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32 offload_ctx, sync_function, manual_controller = get_cpu_offload_context(
# elements enabled=True,
if quantization_recipe.mxfp8(): model_layers=6,
if model_name not in ("linear_op", "layernorm_mlp_ops"): offload_activations=True,
# Modules do not deallocate column-wise MXFP8 data for weights manual_synchronization=True,
return 2 * param_elements * (1 + 1 / 32) / 1024**2 )
return param_elements * (1 + 1 / 32) / 1024**2 layer_1 = Utils.create_layer(layer_type)
layer_2 = Utils.create_layer(layer_type)
inp1 = Utils.create_tensor(None)
inp2 = Utils.create_tensor(None)
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") recipe_ctx = Utils.create_recipe_ctx(recipe)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
def _measure_cached_memory( init_cuda_memory = Utils.get_cuda_memory_mb()
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe], # 1 fwd
cpu_offload: bool, with offload_ctx, recipe_ctx():
) -> float: out_1 = layer_1(inp1, **m_splits)
"""Measure the growth in allocated GPU memory in MiB after a model forward pass. out_1 = sync_function(out_1)
with offload_ctx, recipe_ctx():
out_2 = layer_2(inp2, **m_splits)
out_2 = sync_function(out_2)
mark_not_offload(out_1, out_2)
del inp1, inp2
memory_before_offload = Utils.get_cuda_memory_mb()
manual_controller.start_offload_layer(0)
manual_controller.release_activation_forward_gpu_memory(0)
manual_controller.start_offload_layer(1)
manual_controller.release_activation_forward_gpu_memory(1)
memory_after_offload = Utils.get_cuda_memory_mb()
assert memory_after_offload + EPSILON < memory_before_offload
manual_controller.start_reload_layer(0)
manual_controller.start_reload_layer(1)
memory_after_reload = Utils.get_cuda_memory_mb()
assert memory_after_reload == pytest.approx(memory_before_offload, 0.1)
out_1.sum().backward()
out_2.sum().backward()
@pytest.mark.parametrize("recipe", quantization_recipes)
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
@pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False])
@pytest.mark.parametrize("backend", ["FlashAttention", "FusedAttention", "UnfusedAttention"])
def test_numerics(
self,
recipe,
layer_type,
use_cuda_graphs,
backend,
retain_pinned_cpu_buffers,
):
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
Memory measurement excludes the input and output tensors. recipe_ctx = Utils.create_recipe_ctx(recipe)
""" if use_cuda_graphs and not retain_pinned_cpu_buffers:
pytest.skip(
"Cuda graphs are not yet supported with cpu offloading when"
" retain_pinned_cpu_buffers is False."
)
# Reset memory if backend == "FusedAttention" and use_cuda_graphs:
gc.collect() pytest.skip(
torch.cuda.empty_cache() "Fused attention + cuda graphs is temporarily broken, not because of cpu offloading"
)
# Context and sync function for CPU offloading os.environ["NVTE_FLASH_ATTN"] = "0"
if cpu_offload: os.environ["NVTE_FUSED_ATTN"] = "0"
offload_context, sync_function = te.get_cpu_offload_context( os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True, enabled=True,
num_layers=len(modules), num_layers=1,
model_layers=len(modules) + 1, model_layers=2,
offload_activations=True, offload_activations=True,
offload_weights=False, offload_weights=False,
retain_pinned_cpu_buffers=retain_pinned_cpu_buffers,
)
class Callable(torch.nn.Module):
def __init__(self, offload_ctx=None, sync_function=None):
super().__init__()
self.layers = torch.nn.ModuleList(
[Utils.create_layer(layer_type) for _ in range(2)]
)
self.offload_ctx = offload_ctx
self.sync_function = sync_function
def forward(self, x):
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
) )
is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
for layer in self.layers:
with self.offload_ctx, recipe_ctx():
if is_ops_layer:
x = layer(x, **m_splits)
else: else:
offload_context = contextlib.nullcontext() x = layer(x, is_first_microbatch=False, **m_splits)
sync_function = lambda x: x if self.sync_function is not None:
x = self.sync_function(x)
# Forward pass, with dummy step to trigger offload for last module return x
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
# Backward pass
tensor.sum().backward()
torch.cuda.synchronize()
# Memory usage in MiB callable_offload = Callable(offload_ctx=offload_ctx, sync_function=sync_function)
return memory_after_forward - memory_before_forward callable_no_offload = Callable(offload_ctx=contextlib.nullcontext(), sync_function=None)
# copy parameters
for param_offload, param_no_offload in zip(
callable_offload.parameters(), callable_no_offload.parameters()
):
param_offload.data.copy_(param_no_offload.data)
@pytest.mark.parametrize("quantization_recipe", quantization_recipes) x = Utils.create_tensor(None)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
# Construct model if use_cuda_graphs:
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] callable_offload = te.make_graphed_callables(
if model_name in ["multihead_attention", "transformer_layer"]: callable_offload,
available_backends, *_ = get_available_attention_backends( (x,),
model_config["small"], enabled=recipe is not None,
qkv_dtype=torch.bfloat16, recipe=(Utils.create_recipe_ctx(recipe) if recipe is not None else None),
qkv_layout="sbhd_sbhd_sbhd",
) )
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
# Warmup # warm up (for example to compute sf for delayed scaling)
_warmup_model(modules_list, quantization_recipe) for _ in range(4):
out = callable_offload(x)
out.sum().backward()
out = callable_no_offload(x)
out.sum().backward()
callable_offload.zero_grad(set_to_none=True)
out_offload = callable_offload(x)
out_offload.sum().backward()
# save out and gradients
offload_outs = [out_offload]
for param in callable_offload.parameters():
offload_outs.append(param.detach().clone())
# Measure cached memory after forward pass torch.cuda.reset_peak_memory_stats()
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) out_no_offload = callable_no_offload(x)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) out_no_offload.sum().backward()
# Check for expected memory usage # collect gradients
assert memory_with_offload < memory_without_offload no_offload_outs = [out_no_offload]
memory_from_cached_weights = _estimate_cached_weight_size( for param in callable_no_offload.parameters():
model_name, no_offload_outs.append(param.detach().clone())
modules_list,
quantization_recipe, # check if tensors are the same
for i in range(len(offload_outs)):
assert torch.allclose(offload_outs[i], no_offload_outs[i]), f"Error in tensor {i}."
torch.cuda.synchronize()
def test_example_from_doc(self):
offload_stream = torch.cuda.Stream()
num_layers = 10
layers = [Utils.create_layer("transformer_layer") for _ in range(num_layers)]
inp = [Utils.create_tensor(None) for _ in range(num_layers)]
out = [None] * num_layers
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True,
model_layers=num_layers,
manual_synchronization=True,
offload_stream=offload_stream,
) )
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
for i in range(num_layers):
with cpu_offload_context:
out[i] = layers[i].forward(inp[i])
out[i] = sync_function(out[i])
manual_controller.start_offload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
manual_controller.release_activation_forward_gpu_memory(i)
for i in range(num_layers - 1, -1, -1):
# these calls are intended to be done in the backward pass
manual_controller.start_reload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
out[i].sum().backward()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import contextlib
import gc
import os
from typing import Iterable, Optional
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"
# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"
# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
"linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
"multihead_attention": lambda: te.MultiheadAttention(
SIZE, NUM_HEADS, params_dtype=torch.bfloat16
),
"transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
}
def _make_input() -> torch.Tensor:
"""Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _warmup_model(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> None:
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
def _measure_cached_memory(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
cpu_offload: bool,
) -> float:
"""Measure the growth in allocated GPU memory in MiB after a model forward pass.
Memory measurement excludes the input and output tensors.
"""
# Reset memory
gc.collect()
torch.cuda.empty_cache()
# Context and sync function for CPU offloading
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=len(modules),
model_layers=len(modules) + 1,
offload_activations=True,
offload_weights=False,
)
else:
offload_context = contextlib.nullcontext()
sync_function = lambda x: x
# Forward pass, with dummy step to trigger offload for last module
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
# Backward pass
tensor.sum().backward()
torch.cuda.synchronize()
# Memory usage in MiB
return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
# Construct model
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
if model_name in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
# Warmup
_warmup_model(modules_list, quantization_recipe)
# Measure cached memory after forward pass
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)
# Check for expected memory usage
assert memory_with_offload < memory_without_offload
memory_from_cached_weights = _estimate_cached_weight_size(
model_name,
modules_list,
quantization_recipe,
)
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
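
The legacy tests above run only when the old code path is selected. A hedged sketch of opting into that path, using the environment flags read in the diffs above and below (NVTE_CPU_OFFLOAD_V1 is read at module import time in cpu_offload.py, so it must be set before transformer_engine is imported):

    import os

    # Flag names taken from this diff; "1" selects the legacy (v1) offload path,
    # unset or "0" selects the refactored path.
    os.environ["NVTE_CPU_OFFLOAD_V1"] = "1"
    # The legacy tests also disable flash attention, whose internal tensors
    # cannot be offloaded on that path.
    os.environ["NVTE_FLASH_ATTN"] = "0"

    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
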
@@ -50,6 +50,13 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
)
from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_activation_offload,
NVTE_CPU_OFFLOAD_V1,
)
from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded
# Import attention utils
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
@@ -737,6 +744,9 @@ class FlashAttention(torch.nn.Module):
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
if is_cpu_offload_enabled():
start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
# get batch_size, max_seqlen and cu_seqlens
batch_size, context_len = None, None
if inference_params is None:
@@ -877,12 +887,7 @@
fp8_output=fp8_output,
)
else:
-from transformer_engine.pytorch.cpu_offload import (
-    CPUOffloadEnabled,
-    mark_activation_offload,
-)
-if CPUOffloadEnabled:
+if is_cpu_offload_enabled():
mark_activation_offload(
query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
)
@@ -1116,6 +1121,9 @@ class FusedAttnFunc(torch.autograd.Function):
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
if is_cpu_offload_enabled():
start_offload(q, k, v, offload_base_tensor=True)
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
@@ -1293,12 +1301,7 @@
# used when some tensors are base tensors and loose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
-from transformer_engine.pytorch.cpu_offload import (
-    CPUOffloadEnabled,
-    mark_activation_offload,
-)
-if CPUOffloadEnabled:
+if is_cpu_offload_enabled() and NVTE_CPU_OFFLOAD_V1:
if ctx.fp8:
tensor_list = fp8_tensors
else:
@@ -1309,6 +1312,7 @@
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors,
*qkvo_tensors,
@@ -1339,14 +1343,11 @@
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
-from transformer_engine.pytorch.cpu_offload import (
-    CPUOffloadedLayer,
-)
+if NVTE_CPU_OFFLOAD_V1:
# If interleaved tensor is offloaded, reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
-if CPUOffloadedLayer and CPUOffloadEnabled:
+if is_current_layer_offloaded() and is_cpu_offload_enabled():
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
@@ -1362,6 +1363,8 @@
ctx.qkv_layout = reload_layout[:-1]
else:
ctx.qkv_layout = qkv_layout
else:
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
...
@@ -1494,14 +1494,6 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_output=fp8_output,
)
-from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
-if CPUOffloadEnabled:
-    warnings.warn(
-        "Attention activation Offloading is only implemented"
-        "with Flash Attention and Fused Attention!"
-    )
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention:
...
@@ -33,6 +33,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProduc
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled

# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
@@ -971,7 +973,8 @@ class MultiheadAttention(torch.nn.Module):
# ===========================
# Core attention computation
# ===========================
if is_cpu_offload_enabled():
start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
context_layer = self.core_attention(
query_layer,
key_layer,
...
@@ -3,698 +3,748 @@
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
-from contextlib import nullcontext
-from typing import Any, Dict, Optional
+import contextlib
+from collections import defaultdict
+from dataclasses import dataclass, field
+import os
+import warnings
+from typing import Any, Optional
import torch
-from torch.autograd.graph import saved_tensors_hooks
from transformer_engine.debug.pytorch.debug_state import TEDebugState
-from .quantized_tensor import QuantizedTensorStorage
-from .tensor.float8_tensor import Float8Tensor
+import transformer_engine.pytorch as te
+import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path
+from .quantized_tensor import (
+    restore_from_saved,
+    prepare_for_saving,
+)

-__all__ = ["get_cpu_offload_context"]
+__all__ = ["get_cpu_offload_context", "mark_not_offload", "start_offload"]

-CPUOffloadEnabled = False
-CPUOffloadedLayer = False
+NVTE_CPU_OFFLOAD_V1 = os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"
OFFLOAD_SYNCHRONIZER = None


def is_cpu_offload_enabled():
    """Returns True if CPU offload is enabled."""
    if NVTE_CPU_OFFLOAD_V1:
        return v1_code_path.is_cpu_offload_enabled()
    return OFFLOAD_SYNCHRONIZER is not None


def mark_activation_offload(*tensors):
    """Set the type of the offloading needed for a tensor."""
    if NVTE_CPU_OFFLOAD_V1:
        v1_code_path.mark_activation_offload(*tensors)
def mark_not_offload(*tensors: torch.Tensor):
    """Marks tensors to prevent them from being offloaded."""
    if NVTE_CPU_OFFLOAD_V1:
        return
    tensors, tensor_obj = prepare_for_saving(*tensors)
    for tensor in tensors:
        if tensor is not None:
            setattr(tensor, "_TE_do_not_offload", True)
    restore_from_saved(tensor_obj, tensors)


def start_offload(*tensors: torch.Tensor, offload_base_tensor: bool = False):
    """
    Marks a point on the main stream where tensors are fully computed and ready to be offloaded.
    If offload_base_tensor is True and the tensor is a view, the base tensor is offloaded
    and reloaded - the stride and storage offset of the view are saved and restored after reload.
    It is useful when multiple tensors are views of the same base tensor,
    for example in MultiHeadAttention for interleaved q, k, v tensors.
    """
    if NVTE_CPU_OFFLOAD_V1:
        return

    def _mark_tensor_for_offload(t):
        if t is None:
            return
        # Attach an event marking when the tensor is ready to be offloaded.
        t.start_reload_event = torch.cuda.Event()
        t.start_reload_event.record(torch.cuda.current_stream())
        if offload_base_tensor and t._base is not None:
            setattr(t, "offload_base_tensor", True)

    tensors, tensor_obj = prepare_for_saving(*tensors)
    for tensor in tensors:
        _mark_tensor_for_offload(tensor)
    restore_from_saved(tensor_obj, tensors)
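A minimal sketch (hypothetical module, not part of this file) of how these two helpers are meant to be used from a layer whose forward runs under get_cpu_offload_context: start_offload records the point where an activation is ready so its copy can begin before the layer finishes, and mark_not_offload exempts tensors that are not worth offloading.

class MyLayer(torch.nn.Module):
    """Hypothetical layer used only for illustration."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        hidden = self.proj(x)
        if is_cpu_offload_enabled():
            start_offload(hidden)    # hidden is ready; its D2H copy may start early
        scale = hidden.abs().mean()
        mark_not_offload(scale)      # tiny tensor, keep it on the GPU
        return hidden * scale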
@dataclass
class TensorGroup:
    """
    TensorGroup is a collection of tensors, events and auxiliary data.
    It is used multiple times in the CPU offload code.
    """

    tensor_list: list[torch.Tensor] = field(default_factory=list)
    events: list[torch.cuda.Event] = field(default_factory=list)
    aux: Any = None


class TensorGroupProcessor:
    """
    Suppose there is a tensor group T that needs to be offloaded.
    Possibly we can switch T into (T_opt, aux), where T_opt is smaller and easier to offload,
    offload T_opt, reload it and then restore T from (T_opt_reloaded, aux).

    This class contains static methods that perform these optimizations - for example
    deduplication of tensors and restoring duplicates after reload.
    """
@staticmethod
def tensor_group_process_before_offload(tensor_group: TensorGroup) -> tuple[TensorGroup, Any]:
"""
Call for a tensor group, just before offloading logic.
class GroupCommitFunction(torch.autograd.Function): aux is a dictionary that contains auxiliary data, needed to restore pre-offload state.
"""this is a dummy op with output identical to input.
However, it is necessary for marking a timepoint for offload handler to
accomplish all synchronizations. Implementing it as a function is necessary
because we need to actions in both forward and backward.
""" """
aux = {}
tensor_group = TensorGroupProcessor._switch_to_base_tensors(aux, tensor_group)
tensor_group = TensorGroupProcessor._deduplicate_tensors(aux, tensor_group)
return tensor_group, aux
@staticmethod @staticmethod
def forward(ctx, tensor, cpu_offload_handler): def tensor_group_process_after_reload(tensor_group: TensorGroup):
# pylint: disable=missing-function-docstring """
cpu_offload_handler.on_group_commit_forward() Call for a tensor group, just after reload logic.
ctx.cpu_offload_handler = cpu_offload_handler """
# return the identical tensor assert tensor_group.aux is not None
return tensor tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group)
tensor_group = TensorGroupProcessor._switch_to_views(tensor_group)
return tensor_group
@staticmethod @staticmethod
def backward(ctx, grad_output): def _switch_to_base_tensors(aux, tensor_group: TensorGroup) -> TensorGroup:
# pylint: disable=missing-function-docstring """
cpu_offload_handler = ctx.cpu_offload_handler Changes tensors to base tensors and saves view options in aux.
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
If we save multiple tensors which in fact are views of the same base tensor,
this will offload only this one base tensor. It is used for example in
MultiHeadAttention for interleaved q, k, v tensors.
"""
def _check_if_offload_base_tensor(tensor: torch.Tensor) -> bool:
if getattr(tensor, "offload_base_tensor", False):
return True
if tensor._base is not None:
# If tensor is a view of a tensor and has the same elements,
# but with different strides, we can safely offload the base tensor.
# If tensor is a view on some part of a bigger tensor,
# the decision to offload the base tensor is non-trivial and we do not do it by default.
return tensor._base.numel() == tensor.numel()
return False
aux["views"] = []
for tensor_id in range( # pylint: disable=consider-using-enumerate
len(tensor_group.tensor_list)
):
tensor = tensor_group.tensor_list[tensor_id]
if _check_if_offload_base_tensor(tensor):
aux["views"].append((tensor.shape, tensor.stride(), tensor.storage_offset()))
tensor = tensor._base
assert (
tensor is not None
), "Cannot offload base tensor, if the tensor is not a view."
tensor_group.tensor_list[tensor_id] = tensor
else:
aux["views"].append(None)
return tensor_group
group_prefetch_offload_commit = GroupCommitFunction.apply @staticmethod
def _deduplicate_tensors(aux, tensor_group: TensorGroup) -> TensorGroup:
"""
Deduplicate tensors.
"""
dedup_tensors: list[torch.Tensor] = []
dedup_events: list[torch.cuda.Event] = []
tensor_to_index: dict[int, int] = {}
aux["original_tensor_ids"] = []
# If there are several duplicates of the same tensor, with different events,
# we keep only first event - every event is recorded when the tensor is ready to be offloaded,
# so it is the most optimal to use the first event.
for tensor_id, tensor in enumerate(tensor_group.tensor_list):
if id(tensor) in tensor_to_index:
aux["original_tensor_ids"].append(tensor_to_index[id(tensor)])
else:
tensor_to_index[id(tensor)] = len(dedup_tensors)
dedup_tensors.append(tensor)
dedup_events.append(tensor_group.events[tensor_id])
aux["original_tensor_ids"].append(tensor_to_index[id(tensor)])
class SynchronizedGroupOffloadHandler(OffloadHandler): tensor_group.tensor_list = dedup_tensors
"""Offload Handler that offloads/reloads in a synchronized way. tensor_group.events = dedup_events
The device-to-host and host-to-device copying happen in the same stream return tensor_group
as the computation kernels, thus the copying will block computation.
@staticmethod
def _restore_tensor_duplicates(tensor_group: TensorGroup) -> TensorGroup:
""" """
Restore tensor duplicates.
"""
new_tensor_list = []
new_events_list = []
for tensor_id in range(len(tensor_group.aux["original_tensor_ids"])):
original_tensor_id = tensor_group.aux["original_tensor_ids"][tensor_id]
new_tensor_list.append(tensor_group.tensor_list[original_tensor_id])
new_events_list.append(tensor_group.events[original_tensor_id])
def __init__( tensor_group.tensor_list = new_tensor_list
self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False tensor_group.events = new_events_list
) -> None: return tensor_group
super().__init__()
self.num_offload_group = num_offload_group
self.tensor_need_offloading_checker = tensor_need_offloading_checker
self.debug = debug
self.groupid_reset()
def groupid_reset(self):
"""Groupid reset."""
# Data structures to label saved tensors and book-keep their cpu copies.
# Currently, on push, create a new cpu tensor and copies; on pop, copies
# the tensor back to gpu and deletes the cpu tensor.
# These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {}
def on_group_commit_forward(self):
"""On group commit forward."""
# finishing up with updating current group and tensor count
self.current_group += 1 # increment
self.tensor_count_current_group = 0 # reset
def on_group_commit_backward(self):
"""On group commit backward."""
self.current_group -= 1
assert self.current_group >= 0
@staticmethod @staticmethod
def offload(src_tensor, pin_memory=True): def _switch_to_views(tensor_group: TensorGroup) -> TensorGroup:
"""Offload.""" """
Switch to views - reverse of _switch_to_base_tensors.
cpu_backup = torch.empty( """
src_tensor.size(), for tensor_id, tensor in enumerate(tensor_group.tensor_list):
dtype=src_tensor.dtype, if tensor_group.aux["views"][tensor_id] is not None:
layout=src_tensor.layout, tensor_group.tensor_list[tensor_id] = tensor.as_strided(
device="cpu", *tensor_group.aux["views"][tensor_id]
pin_memory=pin_memory,
) )
return tensor_group
cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@staticmethod class OffloadableLayerState:
def reload(state, non_blocking=None, copy_buffer=None): """
"""Reload.""" Class that manages offloading and reloading of tensors for a single layer.
dev, cpu_backup = state """
if non_blocking is None:
non_blocking = cpu_backup.is_pinned()
if copy_buffer is None: def __init__(
return cpu_backup.to(dev, non_blocking=non_blocking) self,
offload_stream: torch.cuda.Stream,
retain_pinned_cpu_buffers: bool = False,
):
self.offload_stream = offload_stream
self.retain_pinned_cpu_buffers = retain_pinned_cpu_buffers
assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" # There are 3 tensor groups: tensors on gpu before offload,
# tensors on cpu after offload, tensors on gpu after reload.
self.fwd_gpu_tensor_group = TensorGroup()
self.cpu_tensor_group = TensorGroup()
self.bwd_gpu_tensor_group = TensorGroup()
copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) self.aux: dict[str, Any] = {}
return copy_buffer # State can be one of: not_offloaded, offload_started,
# offload_finished, reload_started.
self.state = "not_offloaded"
def tensor_push(self, tensor: torch.Tensor, **kwargs): def _validate_state(self, func_name: str, allowed_states: list[str]):
"""Tensor push.""" assert (
# obtain a unique tensor tag self.state in allowed_states
tensor_tag = (self.current_group, self.tensor_count_current_group) ), f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}"
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
tensor
):
state = SynchronizedGroupOffloadHandler.offload(tensor)
self.tensor_tag_to_state[tensor_tag] = state
else:
# will be offloaded together after group commit
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag def start_offload(self):
"""
Start offloading of tensors. Enqueues the GPU to CPU copies on the offload stream.
Before each copy, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
This event is recorded in the start_offload or push_tensor call.
"""
self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
self.state = "offload_started"
def tensor_pop(self, tensor_tag, **kwargs): self.fwd_gpu_tensor_group, aux = TensorGroupProcessor.tensor_group_process_before_offload(
"""Tensor pop.""" self.fwd_gpu_tensor_group
assert tensor_tag in self.tensor_tag_to_state )
state = self.tensor_tag_to_state.pop(tensor_tag)
if isinstance(state, tuple):
tensor = SynchronizedGroupOffloadHandler.reload(state)
else:
tensor = state
return tensor
allocate_cpu_buffers = (
not self.retain_pinned_cpu_buffers or len(self.cpu_tensor_group.tensor_list) == 0
)
class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list):
"""Compared to synchronize, this uses more memory because of the buffer but assert tensor.is_contiguous()
achieves better performance due to the overlapping. D2h and h2d copying are
completely hidden behind computation if computation time of a layer is longer
than host-device communication time. Bulk offloading with delay and bulk reloading
with prefetch are implemented."""
def __init__( # Wait for the moment the tensor is ready to be offloaded.
self, self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type]
num_offload_group, # must be <= actual number of groups (number of commits)
num_model_group, with torch.cuda.stream(self.offload_stream):
tensor_need_offloading_checker=(lambda t: True), if allocate_cpu_buffers:
double_buffering=False, # empty_like is defined also for QuantizedTensors
debug=False, offloaded_tensor = torch.empty_like(
) -> None: tensor, device=torch.device("cpu"), pin_memory=True
super().__init__(
num_offload_group=num_offload_group,
tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug,
) )
# Number of layers in the model self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
self.layer_window_map = {}
# Data structures fo double buffered reloading
self.double_buffering = double_buffering
self.reload_double_buffer = [[], []]
self.double_buffer_created = False
# Logic to make offloading load balance across computation
# for optimal CPU/GPU interconnect usage
constant = 0
for i in range(self.num_offload_group):
self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
if i < (self.num_layers % self.num_offload_group):
self.layer_window_map[i] += i + 1
constant = i + 1
else: else:
self.layer_window_map[i] += constant assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
"CPU buffer shape does not match the offloaded tensor shape:"
f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
" Make sure that tensor shaped do not change between"
" iterations if retain_pinned_cpu_buffers is True."
)
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
offloaded_tensor.copy_(tensor, non_blocking=True)
# allocate streams and events for synchronization # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
self.d2h_stream = torch.cuda.Stream() # needed to restore pre-offload state after reload.
self.h2d_stream = torch.cuda.Stream() self.aux = aux
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: self.finish_offload_event = torch.cuda.Event()
global CPUOffloadedLayer self.finish_offload_event.record(self.offload_stream)
torch_stray_tensor = isinstance( def release_activation_forward_gpu_memory(self):
tensor, """
( Release GPU memory of the activations.
torch._subclasses.fake_tensor.FakeTensor, Waits for offload to finish - memory needs to be kept alive when GPU->CPU copy is performed.
torch._subclasses.functional_tensor.FunctionalTensor, """
), self._validate_state(
func_name="release_activation_forward_gpu_memory", allowed_states=["offload_started"]
) )
self.state = "offload_finished"
torch.cuda.current_stream().wait_event(self.finish_offload_event) # type: ignore[arg-type]
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage) # GPU memory can be released safely after the offload.
# Notice that the memory needs to be kept alive when GPU->CPU copy is performed.
self.fwd_gpu_tensor_group = TensorGroup()
del self.finish_offload_event
if not torch_stray_tensor: def start_reload(self):
"""
Start reloading of tensors.
It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.
"""
self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
self.state = "reload_started"
# obtain a unique tensor tag self.bwd_gpu_tensor_group = TensorGroup()
tensor_tag = (self.current_group, self.tensor_count_current_group) for tensor in self.cpu_tensor_group.tensor_list:
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state # Notice that reloaded tensor is allocated on main stream,
# not offloaded stream. It is because PyTorch memory allocator
# cannot move tensors from pool of one stream to another without
# calling cudaFree and cudaMalloc again.
if is_quantized_tensor: # empty_like is defined also for QuantizedTensors.
tensor_list, _ = tensor.prepare_for_saving() reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
self.offload_stream.wait_stream(torch.cuda.current_stream())
self.tensor_tag_to_state[tensor_tag] = [] with torch.cuda.stream(self.offload_stream):
self.tensor_tag_to_buf[tensor_tag] = [] reloaded_tensor.copy_(tensor, non_blocking=True)
# Added support for de-duplicating FP8 param tensors reload_tensor_event = torch.cuda.Event()
for _, value in self.fp8_tensor_object_map.items(): reload_tensor_event.record(self.offload_stream)
if tensor is value: self.bwd_gpu_tensor_group.events.append(reload_tensor_event)
self.dereferencing_list.append(tensor_tag) self.bwd_gpu_tensor_group.tensor_list.append(reloaded_tensor)
break
self.fp8_tensor_object_map[tensor_tag] = tensor self.bwd_gpu_tensor_group.aux = self.aux
if isinstance(tensor, Float8Tensor): self.bwd_gpu_tensor_group = TensorGroupProcessor.tensor_group_process_after_reload(
self.float8_transpose_cache_valid[tensor_tag] = getattr( self.bwd_gpu_tensor_group
tensor, "_transpose_invalid"
) )
else:
tensor_list = [tensor]
for t in tensor_list: def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
if is_quantized_tensor: """
self.tensor_tag_to_state[tensor_tag].append(t) It is called when a tensor is saved for backward pass.
If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
If tensor is not offloaded, returns the tensor itself.
"""
self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])
if self._check_if_offload(tensor):
self.fwd_gpu_tensor_group.tensor_list.append(tensor)
# The group is processed and offloaded at the end of the forward pass of current layer.
# To enable offloading of tensors faster we use self.offload_stream and record
# the events when the tensors are ready to be offloaded.
# It means that we do not need to wait to the end of current layer to start offloading.
if hasattr(tensor, "start_reload_event"):
self.fwd_gpu_tensor_group.events.append(tensor.start_reload_event)
else: else:
self.tensor_tag_to_state[tensor_tag] = t self.fwd_gpu_tensor_group.events.append(torch.cuda.Event())
self.fwd_gpu_tensor_group.events[-1].record(torch.cuda.current_stream())
return len(self.fwd_gpu_tensor_group.tensor_list) - 1
return tensor
def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
"""
It is called when a tensor is used in backward pass.
Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
"""
self._validate_state(
func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
)
# 1. tensor not offloaded
if isinstance(tensor_or_tensor_id, torch.Tensor):
return tensor_or_tensor_id
# 2. the layer was not offloaded at all
if self.state == "not_offloaded":
return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
# 3. the layer was offloaded
assert self.state == "reload_started"
# wait for the tensor to be reloaded
torch.cuda.current_stream().wait_event(
self.bwd_gpu_tensor_group.events[tensor_or_tensor_id]
)
return self.bwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
def release_all_memory(self):
"""Release all gpu and cpu memory the state stored. Is called after the backward pass."""
self.fwd_gpu_tensor_group = TensorGroup()
if not self.retain_pinned_cpu_buffers:
self.cpu_tensor_group = TensorGroup()
self.bwd_gpu_tensor_group = TensorGroup()
self.state = "not_offloaded"
def _check_if_offload(self, t: torch.Tensor) -> bool:
"""
Check if tensor needs to be offloaded.
"""
if ( if (
self.current_group < self.num_offload_group not isinstance(t, torch.nn.Parameter)
and self.tensor_need_offloading_checker(t) and not getattr(t, "_TE_do_not_offload", False)
and not isinstance(t, torch._subclasses.FakeTensor)
and t.device.type == "cuda"
): ):
if is_quantized_tensor: if not t.is_contiguous() and not getattr(t, "offload_base_tensor", False):
self.tensor_tag_to_buf[tensor_tag].append(t) warnings.warn(
# Need to clear the internal data reference for the quantized tensors "Tried to offload non-contiguous tensor, which is not supported. Offload of"
tensor.clear() " this tensor will be skipped."
else: )
self.tensor_tag_to_buf[tensor_tag] = t return False
# Needed to differentiate non offloaded layer's attention return True
# QKV layout of attention of non-offloaded layer needs return False
# to be modified while reloading
CPUOffloadedLayer = True def get_offloaded_total_size_mb(self) -> float:
else: """
tensor_tag = (-1, self.torch_tensor_count) Get total size of offloaded tensors in MB, used only for testing.
self.torch_tensor_count += 1 """
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag def get_tensor_size_mb(tensor):
if tensor is None:
return 0
if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage):
return sum(get_tensor_size_mb(t) for t in tensor.get_data_tensors())
return tensor.numel() * tensor.element_size() / (1024**2)
def tensor_pop(self, tensor_tag, **kwargs): total_size = 0
"""Tensor pop.""" for tensor in self.cpu_tensor_group.tensor_list:
global CPUOffloadedLayer total_size += get_tensor_size_mb(tensor)
return total_size
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
# Handling the quantized tensor case specially here class OffloadSynchronizer:
if isinstance(tensor, list): """
# If it's a duplicated tensor, we don't need to locally Base class responsible for synchronizing offloading and reloading of tensors for multiple layers.
# write back a tensor as it would already be written In base class we only track layer number and
if tensor_tag in self.dereferencing_list: create OffloadableLayerState instances for all layers, but do not start offloading or reloading.
self.dereferencing_list.remove(tensor_tag) """
def __init__(
self,
num_layers: int,
retain_pinned_cpu_buffers: bool = False,
offload_stream: Optional[torch.cuda.Stream] = None,
):
self.num_layers = num_layers
self.offload_stream = offload_stream if offload_stream is not None else torch.cuda.Stream()
self.layer_states = {
i: OffloadableLayerState(self.offload_stream, retain_pinned_cpu_buffers)
for i in range(num_layers)
}
self.num_of_fwds = None
self.previous_bwd_layer_id = None
self.current_layer_id = None
def fwd_step(self) -> int:
"""
Invoked before each layer forward.
"""
if self.num_of_fwds in [None, self.num_layers - 1]:
# reset the offload synchronizer
self.num_of_fwds = 0
else: else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) self.num_of_fwds += 1
tensor = self.fp8_tensor_object_map.pop(tensor_tag) self.current_layer_id = self.num_of_fwds
return self.current_layer_id
if self.double_buffering: def bwd_step(self, layer_num: int):
tensor._do_not_clear = True """
Invoked before each layer backward.
"""
if self.previous_bwd_layer_id is not None:
self.layer_states[self.previous_bwd_layer_id].release_all_memory()
self.previous_bwd_layer_id = layer_num
self.current_layer_id = layer_num
self.tensor_tag_to_buf.pop(tensor_tag, None) def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
# the tensor should have been copied back in on_group_commit_backward() """Default push tensor method"""
# which invokes bulk_reload_group. return self.layer_states[self.num_of_fwds].push_tensor(tensor)
assert not isinstance(tensor, tuple)
return tensor
def bulk_offload_group(self, group_to_offload): def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
"""Bulk offload group.""" """Default pop tensor method"""
with torch.cuda.stream(self.d2h_stream): return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)
for tensor_tag, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_tag
if group_id == group_to_offload:
assert not isinstance(state, tuple)
is_quantized_tensor = isinstance(state, list) def finish_part_of_bwd(self):
"""
We need to release memory of backward - this call does that.
It needs to be invoked after every backward pass - there may be
more than one in pipeline parallelism.
if is_quantized_tensor: It is needed, because call bwd_step is invoked before each layer backward,
tensor_list = state but we need to release memory after the backward pass is finished.
self.tensor_tag_to_state[tensor_tag] = [] """
else: if self.previous_bwd_layer_id is not None:
tensor_list = [state] self.layer_states[self.previous_bwd_layer_id].release_all_memory()
self.previous_bwd_layer_id = None
for tensor_on_device in tensor_list:
# `tensor_offloaded` is a hacky way of dealing with columnwise-only
# quantized tensors for CPU offloading. The complication is due to
# the `rowwise_data` being `None`. The offloading checker incorrectly
# returns `False` and the entire `state` ([None, columnwise_tensor])
# is added to the tensor tag state dict. A better design would change
# how quantized tensors are kept track of in the offload handler.
# Currently at every stage it is ensured that a quantized tensor is a
# list whereas a non-quantized tensor is standalone object, which is
# not good! TODO(@sanandaraj5597)
tensor_offloaded = False
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
tensor_offloaded = True
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
if tensor_offloaded:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have
# the first compute completion
if current_group == 0:
self.d2h_stream.wait_stream(torch.cuda.current_stream())
if not self.double_buffer_created:
# Creating the first copy of double buffer for tensors that are offloaded
for tensor_tag, buf in self.tensor_tag_to_buf.items():
if isinstance(buf, list):
for b in buf:
self.reload_double_buffer[0].append(
torch.empty_like(b) if self.double_buffering else None
)
else:
self.reload_double_buffer[0].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.bulk_offload_group(current_group) def get_offloaded_total_size_mb(self) -> float:
"""
# Window map data structure helps us synchronize based on number Get total size of offloaded tensors in MB, used only for testing.
# of layers offloaded """
if self.layer_window_map[self.offloaded_group_count] == current_group: return sum(
self.layer_states[layer_id].get_offloaded_total_size_mb()
# Stream synchronization both ways for layer_id in self.layer_states
self.d2h_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group
if self.offloaded_group_count < (self.num_offload_group - 1):
self.bulk_offload_group(self.offloaded_group_count + 1)
# Increment the offload group count to keep track
self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
for buf in self.reload_double_buffer[0]:
self.reload_double_buffer[1].append(
torch.empty_like(buf) if self.double_buffering else None
) )
self.double_buffer_created = True
def on_group_commit_forward(self):
"""This function will cause host device synchronization"""
# handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group)
super().on_group_commit_forward() class DefaultOffloadSynchronizer(OffloadSynchronizer):
"""
Default implementation of OffloadSynchronizer,
intended to be used in standard training workloads - with multiple forwards
and multiple backwards.
"""
def bulk_reload_group(self, group_to_reload): def __init__(
"""Bulk reload group.""" self,
assert group_to_reload < self.num_offload_group num_layers: int,
num_offloaded_layers: int | None = None,
retain_pinned_cpu_buffers: bool = False,
offload_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__(num_layers, retain_pinned_cpu_buffers, offload_stream)
buffer_idx = 0 # map of layers to bool meaning if layer needs to be offloaded
double_buffer_idx = group_to_reload % 2 self.offload_layer_map: dict[int, bool] = {}
main_stream = torch.cuda.current_stream() # num_layer: int -> list of layers that need to finish offload by this moment
self.finish_offload_map: defaultdict[int, list[int]] = defaultdict(list)
# num_layer: int -> list of layers that need to start reload in this moment
self.start_reload_map: defaultdict[int, list[int]] = defaultdict(list)
with torch.cuda.stream(self.h2d_stream): self._init_offload_synchronization_dicts(num_offloaded_layers)
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple): def _init_offload_synchronization_dicts(self, num_offloaded_layers: int):
if self.double_buffering: """
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] If synchronization dictionary is not provided, the number of offloaded layers is used to initialize
else: offload_layer_map, finish_offload_map and start_reload_map.
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
recovered_tensor = SynchronizedGroupOffloadHandler.reload( The aim is to minimize memory usage by the end of the forward pass.
state, True, reload_buffer
) The optimal strategy for that is to offload layers 0, ..., num_offloaded_layers - 1.
buffer_idx = buffer_idx + 1 For layer i offload needs to finish before num_layers - num_offloaded_layers + i.
self.tensor_tag_to_state[tensor_label] = recovered_tensor For layer i reload needs to start after num_layers - num_offloaded_layers + i.
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
tensor_list.append( This ensures that - if all layers have memory footprint of T - then peak memory usage of saving activations is
SynchronizedGroupOffloadHandler.reload( (num_layers - num_offloaded_layers) * T.
state_tuple, """
True, for layer_id in range(self.num_layers):
reload_buffer, if layer_id < num_offloaded_layers:
self.offload_layer_map[layer_id] = True
self.finish_offload_map[self.num_layers - num_offloaded_layers + layer_id].append(
layer_id
) )
self.start_reload_map[self.num_layers - 1 - num_offloaded_layers + layer_id].append(
layer_id
) )
buffer_idx = buffer_idx + 1
else: else:
tensor_list.append(state_tuple) self.offload_layer_map[layer_id] = False
# No need to write back the duplicated tensor againn def fwd_step(self) -> int:
# to the same location, this check ensures that """
if tensor_label in self.dereferencing_list: Invoked before each layer forward.
self.dereferencing_list.remove(tensor_label) """
else: super().fwd_step()
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved( if self.offload_layer_map.get(self.current_layer_id - 1, False):
tensor_list self.layer_states[self.current_layer_id - 1].start_offload()
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): for layer in self.finish_offload_map[self.current_layer_id]:
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( self.layer_states[layer].release_activation_forward_gpu_memory()
self.float8_transpose_cache_valid.pop(tensor_label) return self.current_layer_id
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( def bwd_step(self, layer_num: int):
tensor_label """
) Invoked before each layer backward.
"""
super().bwd_step(layer_num)
def on_group_commit_backward(self): for layer in self.start_reload_map[layer_num]:
# first decrement the current group. self.layer_states[layer].start_reload()
# after last commit in forward, the group will +1; in backward it -1.
# Finally it should be decremented to 0.
self.current_group -= 1
assert self.current_group >= 0
# Layer window data structure helps us to reload at right times
if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
# Stream synchronization both ways class ManualOffloadSynchronizer(OffloadSynchronizer):
self.h2d_stream.wait_stream(torch.cuda.current_stream()) """
torch.cuda.current_stream().wait_stream(self.h2d_stream) Manual implementation of OffloadSynchronizer,
all synchronization is done manually by the user by using
one of the following methods:
- start_offload_layer
- release_activation_forward_gpu_memory
- start_reload_layer
This implementation is intended to be used in more complex training workflows.
It is useful for example in pipeline parallelism.
"""
# Time to reload the next group def start_offload_layer(self, layer_id: int):
self.bulk_reload_group(self.offloaded_group_count - 1) """
Start offloading of the layer.
Each tensor GPU->CPU copy is done asynchronously on the offload stream.
Start of each copy is started after tensor_push() is called on the current stream.
"""
self.layer_states[layer_id].start_offload()
# Decrease the offloading group counter def release_activation_forward_gpu_memory(self, layer_id: int):
self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 """
Release memory of the activations of the layer.
It waits for the offload of the layer to finish.
"""
self.layer_states[layer_id].release_activation_forward_gpu_memory()
# Last group computation needs to wait till all the reloads complete def start_reload_layer(self, layer_id: int):
if self.current_group == 0: """
torch.cuda.current_stream().wait_stream(self.h2d_stream) Start reloading of the layer.
self.offloaded_group_count = 0 Each tensor reload is awaited to finish before tensor_pop() for that tensor is called on the current stream.
"""
self.layer_states[layer_id].start_reload()
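As a concrete illustration of the schedule that DefaultOffloadSynchronizer._init_offload_synchronization_dicts builds from the formulas above, with hypothetical sizes:

num_layers, num_offloaded = 6, 2
finish_offload = {num_layers - num_offloaded + i: [i] for i in range(num_offloaded)}
start_reload = {num_layers - 1 - num_offloaded + i: [i] for i in range(num_offloaded)}
print(finish_offload)  # {4: [0], 5: [1]}: free layer 0/1 GPU activations when forward of layer 4/5 begins
print(start_reload)    # {3: [0], 4: [1]}: start reloading layer 0/1 when backward reaches layer 3/4
# Offload of layer i itself is kicked off when the forward of layer i + 1 begins (fwd_step),
# so at most (num_layers - num_offloaded) layers of activations stay resident: peak ~= 4 * T.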
def get_cpu_offload_context( def get_cpu_offload_context(
enabled: bool = False, enabled: bool = False,
num_layers: int = 1, num_layers: Optional[int] = 1,
model_layers: int = 1, model_layers: int = 1,
offload_activations: bool = True, offload_activations: bool = True,
offload_weights: bool = False, offload_weights: bool = False,
double_buffering: bool = False, double_buffering: bool = False, # pylint: disable=unused-argument
manual_synchronization: bool = False,
retain_pinned_cpu_buffers: bool = False,
offload_stream: Optional[torch.cuda.Stream] = None,
): ):
""" """
This function returns the CPU Offload context and the synchronizer function that needs to be
used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
CPU offloading feature for sequences of layers. Can be used for arbitrary layers, not necessarily
those provided by TE.
Usage: Usage:
.. code-block:: python .. code-block:: python
cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) cpu_offload_context, sync_function = get_cpu_offload_context(...)
for _ in range(num_layers):
with cpu_offload_context: with cpu_offload_context:
te_layer.forward(inp_tensor) x = layers[i].forward(x)
cpu_offload_synchronizer() x = sync_function(x)
Parameters Parameters
---------- ----------
enabled: bool, default = `False` enabled: bool, default = `False`
When set to True, CPU Offloading functionality is enabled. When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1 num_layers: int, default = 1
Determines the number of transformer layers Determines the number of layers
you want to offload activations/weights for. you want to offload activations/weights for.
model_layers: int, default = 1 model_layers: int, default = 1
Number of layers in the model that will be used under this context. Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True` offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer. Deprecated.
offload_weights: bool, default = `True` offload_weights: bool, default = `True`
When set to `True`, offloads the weights for the TE layer. Deprecated.
double_buffering: bool, default = `False` double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading. Deprecated.
retain_pinned_cpu_buffers: bool, default = `False`
If True, the pinned CPU buffers are retained after offloading
and reused for the next iteration. It is useful for cuda graphs capture.
manual_synchronization: bool, default = `False`
If True, the synchronization is done manually by the user.
Additional argument manual_controller is returned. See more in manual control section.
offload_stream: torch.cuda.Stream, default = `None`
If provided, the offload stream is used for offloading and reloading.
Otherwise, a new stream is allocated internally. It can be other than None
only if manual_synchronization is True.
Manual synchronization
----------
By default, layers are offloaded/reloaded asynchronously
with respect to the current forward/backward stream with predefined synchronization,
to ensure that activation memory usage is equal to
`(num_layers - num_offloaded_layers) * T`, where `T` is the memory footprint of a layer.
For more control over the offloading and reloading process, you can set `manual_synchronization=True`.
In this case, an additional argument, `manual_controller`, is returned.
The `manual_controller` provides the following methods:
- `start_offload_layer(layer_id: int)`
- `release_activation_forward_gpu_memory(layer_id: int)`
- `start_reload_layer(layer_id: int)`
If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded.
If `start_offload_layer()` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored.
To release this memory, you need to call `release_activation_forward_gpu_memory(layer_id)`.
This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
The `start_reload_layer()` method is used to start reloading a layer.
Each tensor reload is awaited to finish before `tensor_pop()` for that tensor is called on the current stream.
You can provide an `offload_stream` to be used for offload and reload operations.
This allows for more detailed synchronization, such as delaying the start of offloading.
Example:
.. code-block:: python
offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)
for i in range(num_layers):
with cpu_offload_context:
out[i] = layers[i].forward(inp[i])
out[i] = sync_function(out[i])
manual_controller.start_offload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
manual_controller.release_activation_forward_gpu_memory(i)
for i in range(num_layers - 1, -1, -1):
manual_controller.start_reload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
out[i].sum().backward()
V1 code path
----------
If you want to use the v1 code path for offloading,
please set the environment variable NVTE_CPU_OFFLOAD_V1 to 1.
""" """
if NVTE_CPU_OFFLOAD_V1:
return v1_code_path.get_cpu_offload_context(
enabled=enabled,
num_layers=num_layers,
model_layers=model_layers,
offload_activations=offload_activations,
offload_weights=offload_weights,
double_buffering=double_buffering,
)
if not offload_weights and not offload_activations: if not offload_weights and not offload_activations:
raise ValueError( raise ValueError(
...@@ -703,8 +753,6 @@ def get_cpu_offload_context( ...@@ -703,8 +753,6 @@ def get_cpu_offload_context(
) )
if offload_weights: if offload_weights:
import warnings
warnings.warn( warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any" "Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.", " effect.",
...@@ -713,26 +761,100 @@ def get_cpu_offload_context( ...@@ -713,26 +761,100 @@ def get_cpu_offload_context(
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing. # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations: if not offload_activations:
return nullcontext(), lambda x: x return contextlib.nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor): if TEDebugState.debug_enabled:
return hasattr(tensor, "activation_offloading") raise RuntimeError("CPU offload is not supported in debug mode.")
if not manual_synchronization:
assert (
num_layers <= model_layers - 1
), "Cannot offload all layers without manual synchronization - last layer is not offloaded."
if num_layers == model_layers - 1:
warnings.warn(
"Offloading num_layers == model_layers - 1 is not recommended, it prevents"
" overlapping of computation and offload/reload."
)
tensor_need_offloading_checker = tensor_need_offloading_checker_activations assert (
offload_stream is None or manual_synchronization
), "offload_stream can be provided only if manual_synchronization is True"
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( if manual_synchronization:
num_offload_group=num_layers, offload_synchronizer = ManualOffloadSynchronizer(
num_model_group=model_layers, model_layers, retain_pinned_cpu_buffers, offload_stream
tensor_need_offloading_checker=tensor_need_offloading_checker, )
double_buffering=double_buffering, else:
offload_synchronizer = DefaultOffloadSynchronizer(
model_layers,
num_layers,
retain_pinned_cpu_buffers,
offload_stream,
)
class _CpuOffloadContext(contextlib.ContextDecorator):
def __init__(self):
self.current_layer = None
self.previous_offload_synchronizer = None
self.offload_synchronizer = offload_synchronizer
self.inside_context = False
def __enter__(self):
assert (
self.inside_context is False
), "Offloading context was entered without synchronization function being called."
self.inside_context = True
self._hooks_ctx = saved_tensors_hooks(
offload_synchronizer.push_tensor, offload_synchronizer.pop_tensor
)
self._hooks_ctx.__enter__()
global OFFLOAD_SYNCHRONIZER
self.previous_offload_synchronizer = OFFLOAD_SYNCHRONIZER
OFFLOAD_SYNCHRONIZER = offload_synchronizer
self.current_layer = offload_synchronizer.fwd_step()
return self
def __exit__(self, *args):
self._hooks_ctx.__exit__(*args)
global OFFLOAD_SYNCHRONIZER
OFFLOAD_SYNCHRONIZER = self.previous_offload_synchronizer
self.inside_context = False
def synchronization_function(self, tensor):
"""
This function is used to catch the backward pass of the model.
"""
assert tensor.requires_grad is True
assert self.current_layer is not None
cur_layer = self.current_layer
assert (
self.inside_context is False
), "Synchronization function was called without offloading context being entered."
def hook(_):
# offload_synchronizer.finish_part_of_bwd needs
# to be called after every backward pass - there may be
# more than one in pipeline parallelism.
torch.autograd.variable.Variable._execution_engine.queue_callback(
offload_synchronizer.finish_part_of_bwd
) )
offload_synchronizer.bwd_step(cur_layer)
tensor.grad_fn.register_prehook(hook)
return tensor
def group_prefetch_offload_commit_async(tensor): cpu_offload_context = _CpuOffloadContext()
return group_prefetch_offload_commit(tensor, cpu_offload_handler)
if enabled: if enabled:
if manual_synchronization:
return (
cpu_offload_context,
cpu_offload_context.synchronization_function,
offload_synchronizer,
)
return ( return (
CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), cpu_offload_context,
group_prefetch_offload_commit_async, cpu_offload_context.synchronization_function,
) )
return nullcontext(), group_prefetch_offload_commit_async return contextlib.nullcontext(), lambda x: x
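A minimal end-to-end sketch of the default (non-manual) path, assuming a stack of identical Transformer Engine layers; the sizes and module choice below are placeholders:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context

num_blocks = 12
blocks = torch.nn.ModuleList(
    [te.TransformerLayer(1024, 4096, 16) for _ in range(num_blocks)]
).cuda()

# Offload activations of 4 of the 12 blocks; wrap each block's forward in the context
# and pass each block's output through the returned synchronization function.
offload_ctx, sync_fn = get_cpu_offload_context(
    enabled=True, num_layers=4, model_layers=num_blocks
)

x = torch.randn(2048, 2, 1024, device="cuda", requires_grad=True)
for block in blocks:
    with offload_ctx:
        x = block(x)
    x = sync_fn(x)
x.sum().backward()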
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
from contextlib import nullcontext
from typing import Any, Dict, Optional
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False
CPUOffloadedLayer = False
def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor."""
if TEDebugState.debug_enabled:
raise RuntimeError("CPU offload is not supported in debug mode.")
for tensor in tensors:
if tensor is None:
continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
tensor.activation_offloading = True
else:
data_tensors = tensor.get_data_tensors()
for tensor in data_tensors:
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorStorage classes are saved in the ctx,
# and they contain the reference to their data tensors.
tensor.needs_force_clear = True
def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled
def is_current_layer_offloaded() -> bool:
"""Check if current layers is being offloaded."""
return CPUOffloadedLayer
class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
In this context, the ``on_save_for_backward`` method will be called every time
a tensor is saved for backward (this includes intermediary results saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation).
The ``on_get_saved_tensors`` method will be called when the backward function
of this op attempts to retrieve the saved tensor from context (this includes
:func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the
as input the return value of the ``on_save_for_backward``, and is meant to return
an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of
size, device and element values.
Example:
>>> import torch
>>> from typing import Any
>>>
>>> class DummyHook(CpuOffloadSavedTensorHook):
...
... def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
... logging.info("On save", tensor)
... return (tensor,)
...
... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
... logging.info("On get", saved_state)
... tensor, = saved_state
... return tensor
...
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with DummyHook():
... y = a * b
...
On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)
"""
def __init__(self) -> None:
self.inside_context = False
def __enter__(self):
global CPUOffloadEnabled
CPUOffloadEnabled = True
self.inside_context = True
torch._C._autograd._push_saved_tensors_default_hooks(
self.on_save_for_backward, self.on_get_saved_tensor
)
def __exit__(self, *args: Any):
global CPUOffloadEnabled
CPUOffloadEnabled = False
self.inside_context = False
torch._C._autograd._pop_saved_tensors_default_hooks()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
"""On save for backward."""
raise NotImplementedError(
"`on_save_for_backward: Callable[[torch.Tensor], Any]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
"""On get saved tensor."""
raise NotImplementedError(
"`on_get_saved_tensors: Callable[[Any], torch.Tensor]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
"""Context-manager that offloads/recovers tensors through an offload hander.
The hook just offloads/recovers the tensor object to the handler through `tensor_push`
and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
or prefetching timing is transparent to this hook.
"""
def __init__(
self,
offload_handler: OffloadHandler,
handler_extra_kwargs: Optional[Dict[str, Any]] = None,
debug: bool = False,
) -> None:
if handler_extra_kwargs is None:
handler_extra_kwargs = {}
self.debug: bool = debug
self.offload_handler: OffloadHandler = offload_handler
self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)
return tensor
class OffloadHandler:
"""A base class for CPU offload-handler."""
def __init__(self) -> None:
pass
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
"""Tensor push."""
raise NotImplementedError(
"`tensor_push is not implented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_push."
)
def tensor_pop(self, tensor_tag: Any, **kwargs):
"""Tensor pop."""
raise NotImplementedError(
"`tensor_pop is not implented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_pop."
)
class GroupCommitFunction(torch.autograd.Function):
"""this is a dummy op with output identical to input.
However, it is necessary for marking a timepoint for offload handler to
accomplish all synchronizations. Implementing it as a function is necessary
because we need to actions in both forward and backward.
"""
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
return tensor
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
group_prefetch_offload_commit = GroupCommitFunction.apply
class SynchronizedGroupOffloadHandler(OffloadHandler):
"""Offload Handler that offloads/reloads in a synchronized way.
The device-to-host and host-to-device copying happen in the same stream
as the computation kernels, thus the copying will block computation.
"""
def __init__(
self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
) -> None:
super().__init__()
self.num_offload_group = num_offload_group
self.tensor_need_offloading_checker = tensor_need_offloading_checker
self.debug = debug
self.groupid_reset()
def groupid_reset(self):
"""Groupid reset."""
# Data structures to label saved tensors and book-keep their CPU copies.
# Currently, on push we create a new CPU tensor and copy into it; on pop we
# copy the tensor back to GPU and delete the CPU copy.
# These counters are updated whenever `group_commit()` is invoked.
self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {}
def on_group_commit_forward(self):
"""On group commit forward."""
# finishing up with updating current group and tensor count
self.current_group += 1 # increment
self.tensor_count_current_group = 0 # reset
def on_group_commit_backward(self):
"""On group commit backward."""
self.current_group -= 1
assert self.current_group >= 0
@staticmethod
def offload(src_tensor, pin_memory=True):
"""Offload."""
cpu_backup = torch.empty(
src_tensor.size(),
dtype=src_tensor.dtype,
layout=src_tensor.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@staticmethod
def reload(state, non_blocking=None, copy_buffer=None):
"""Reload."""
dev, cpu_backup = state
if non_blocking is None:
non_blocking = cpu_backup.is_pinned()
if copy_buffer is None:
return cpu_backup.to(dev, non_blocking=non_blocking)
assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!"
copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
return copy_buffer
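# Round-trip sketch (illustrative only): with pinned host memory both copies
# can be issued as non-blocking transfers on the current stream.
#
#     state = SynchronizedGroupOffloadHandler.offload(gpu_tensor)        # D2H
#     gpu_tensor_again = SynchronizedGroupOffloadHandler.reload(state)   # H2D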
def tensor_push(self, tensor: torch.Tensor, **kwargs):
"""Tensor push."""
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
tensor
):
state = SynchronizedGroupOffloadHandler.offload(tensor)
self.tensor_tag_to_state[tensor_tag] = state
else:
# keep the tensor on device; it is not selected for offloading
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
state = self.tensor_tag_to_state.pop(tensor_tag)
if isinstance(state, tuple):
tensor = SynchronizedGroupOffloadHandler.reload(state)
else:
tensor = state
return tensor
class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Compared to synchronize, this uses more memory because of the buffer but
achieves better performance due to the overlapping. D2h and h2d copying are
completely hidden behind computation if computation time of a layer is longer
than host-device communication time. Bulk offloading with delay and bulk reloading
with prefetch are implemented."""
def __init__(
self,
num_offload_group, # must be <= actual number of groups (number of commits)
num_model_group,
tensor_need_offloading_checker=(lambda t: True),
double_buffering=False,
debug=False,
) -> None:
super().__init__(
num_offload_group=num_offload_group,
tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug,
)
# Number of layers in the model
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
self.layer_window_map = {}
# Data structures for double-buffered reloading
self.double_buffering = double_buffering
self.reload_double_buffer = [[], []]
self.double_buffer_created = False
# Logic to load-balance offloading across computation
# for optimal CPU/GPU interconnect usage
constant = 0
for i in range(self.num_offload_group):
self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
if i < (self.num_layers % self.num_offload_group):
self.layer_window_map[i] += i + 1
constant = i + 1
else:
self.layer_window_map[i] += constant
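# Worked example (illustrative numbers): with num_model_group=5 and
# num_offload_group=2 the loop above yields layer_window_map = {0: 2, 1: 4},
# i.e. group 0's D2H traffic is synchronized after the commit of layer 2 and
# group 1's after layer 4, spreading offload traffic across the forward pass.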
# allocate streams and events for synchronization
self.d2h_stream = torch.cuda.Stream()
self.h2d_stream = torch.cuda.Stream()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
)
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if is_quantized_tensor:
tensor_list, _ = tensor.prepare_for_saving()
self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = []
# De-duplicate FP8 param tensors that were already pushed
for _, value in self.fp8_tensor_object_map.items():
if tensor is value:
self.dereferencing_list.append(tensor_tag)
break
self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr(
tensor, "_transpose_invalid"
)
else:
tensor_list = [tensor]
for t in tensor_list:
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(t)
else:
self.tensor_tag_to_state[tensor_tag] = t
if (
self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(t)
):
if is_quantized_tensor:
self.tensor_tag_to_buf[tensor_tag].append(t)
# Need to clear the internal data reference for the quantized tensors
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
# Needed to differentiate offloaded layers from non-offloaded ones:
# the QKV layout of attention in non-offloaded layers needs
# to be modified while reloading
CPUOffloadedLayer = True
else:
tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
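# Tagging scheme for reference: regular activations are tagged
# (group_id, index_within_group), e.g. the third tensor saved by layer 0 gets
# the tag (0, 2), while fake/functional "stray" tensors are tagged
# (-1, running_count) and are never offloaded.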
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
global CPUOffloadedLayer
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
# Handling the quantized tensor case specially here
if isinstance(tensor, list):
# If it's a duplicated tensor, we don't need to write it back
# here, as it has already been written back for the original tag
if tensor_tag in self.dereferencing_list:
self.dereferencing_list.remove(tensor_tag)
else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag)
if self.double_buffering:
tensor._do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group.
assert not isinstance(tensor, tuple)
return tensor
def bulk_offload_group(self, group_to_offload):
"""Bulk offload group."""
with torch.cuda.stream(self.d2h_stream):
for tensor_tag, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_tag
if group_id == group_to_offload:
assert not isinstance(state, tuple)
is_quantized_tensor = isinstance(state, list)
if is_quantized_tensor:
tensor_list = state
self.tensor_tag_to_state[tensor_tag] = []
else:
tensor_list = [state]
for tensor_on_device in tensor_list:
# `tensor_offloaded` is a hacky way of dealing with columnwise-only
# quantized tensors for CPU offloading. The complication is due to
# the `rowwise_data` being `None`. The offloading checker incorrectly
# returns `False` and the entire `state` ([None, columnwise_tensor])
# is added to the tensor tag state dict. A better design would change
# how quantized tensors are kept track of in the offload handler.
# Currently at every stage it is ensured that a quantized tensor is a
list whereas a non-quantized tensor is a standalone object, which is
# not good! TODO(@sanandaraj5597)
tensor_offloaded = False
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
tensor_offloaded = True
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
if tensor_offloaded:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have
# the first compute completion
if current_group == 0:
self.d2h_stream.wait_stream(torch.cuda.current_stream())
if not self.double_buffer_created:
# Creating the first copy of double buffer for tensors that are offloaded
for tensor_tag, buf in self.tensor_tag_to_buf.items():
if isinstance(buf, list):
for b in buf:
self.reload_double_buffer[0].append(
torch.empty_like(b) if self.double_buffering else None
)
else:
self.reload_double_buffer[0].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.bulk_offload_group(current_group)
# Window map data structure helps us synchronize based on number
# of layers offloaded
if self.layer_window_map[self.offloaded_group_count] == current_group:
# Stream synchronization both ways
self.d2h_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group
if self.offloaded_group_count < (self.num_offload_group - 1):
self.bulk_offload_group(self.offloaded_group_count + 1)
# Increment the offload group count to keep track
self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
for buf in self.reload_double_buffer[0]:
self.reload_double_buffer[1].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.double_buffer_created = True
def on_group_commit_forward(self):
"""This function will cause host device synchronization"""
# handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group)
super().on_group_commit_forward()
def bulk_reload_group(self, group_to_reload):
"""Bulk reload group."""
assert group_to_reload < self.num_offload_group
buffer_idx = 0
double_buffer_idx = group_to_reload % 2
main_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, reload_buffer
)
buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
True,
reload_buffer,
)
)
buffer_idx = buffer_idx + 1
else:
tensor_list.append(state_tuple)
# No need to write the duplicated tensor back to the same
# location again; this check ensures that
if tensor_label in self.dereferencing_list:
self.dereferencing_list.remove(tensor_label)
else:
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
tensor_list
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label)
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label
)
def on_group_commit_backward(self):
# First decrement the current group.
# Each commit in forward incremented the counter, so after the last forward
# commit it equals the number of groups; each backward commit decrements it,
# and it should finally reach 0.
self.current_group -= 1
assert self.current_group >= 0
# The layer window data structure helps us reload at the right times
if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
# Stream synchronization both ways
self.h2d_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.h2d_stream)
# Time to reload the next group
self.bulk_reload_group(self.offloaded_group_count - 1)
# Decrease the offloading group counter
self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0
# Last group computation needs to wait till all the reloads complete
if self.current_group == 0:
torch.cuda.current_stream().wait_stream(self.h2d_stream)
self.offloaded_group_count = 0
def get_cpu_offload_context(
enabled: bool = False,
num_layers: int = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
double_buffering: bool = False,
):
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
Usage:
.. code-block:: python
cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)
with cpu_offload_context:
out = te_layer.forward(inp_tensor)
out = cpu_offload_synchronizer(out)
Parameters
----------
enabled: bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1
Determines the number of transformer layers
you want to offload activations/weights for.
model_layers: int, default = 1
Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `False`
Deprecated; setting it to `True` has no effect.
double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading.
"""
if not offload_weights and not offload_activations:
raise ValueError(
"CPU Offloading is enabled while it is not "
"mentioned what to offload (weights/activations)"
)
if offload_weights:
import warnings
warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.",
DeprecationWarning,
)
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
return nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker,
double_buffering=double_buffering,
)
def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor, cpu_offload_handler)
if enabled:
return (
CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
group_prefetch_offload_commit_async,
)
return nullcontext(), group_prefetch_offload_commit_async
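# A slightly fuller sketch of the intended call pattern (illustrative only;
# `layers`, `hidden`, and the choice of offloading all but the last layer are
# assumptions of this example, not requirements of the API):
#
#     context, sync = get_cpu_offload_context(
#         enabled=True, num_layers=len(layers) - 1, model_layers=len(layers)
#     )
#     for layer in layers:
#         with context:
#             hidden = layer(hidden)
#         hidden = sync(hidden)
#     hidden.sum().backward()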
@@ -41,7 +41,7 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
-from ..cpu_offload import is_cpu_offload_enabled
+from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..quantized_tensor import (
@@ -135,6 +135,9 @@ class _GroupedLinear(torch.autograd.Function):
else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
+if cpu_offloading:
+start_offload(*inputmats)
# Initialize weights
weights_fp8: list
if fp8:
@@ -196,6 +199,9 @@ class _GroupedLinear(torch.autograd.Function):
for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i])
+if cpu_offloading:
+mark_not_offload(*weights_fp8, *weights)
if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1]
...
@@ -66,10 +66,15 @@ from ..quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..cpu_offload import (
+is_cpu_offload_enabled,
+start_offload,
+mark_not_offload,
+mark_activation_offload,
+)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up
-from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
general_gemm,
@@ -158,6 +163,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
+if is_cpu_offload_enabled():
+start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
weight_requires_grad = weight.requires_grad
@@ -434,8 +442,14 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
+mark_not_offload(
+weightmat,
+weight,
+bias,
+ln_weight,
+ln_bias,
+)
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
@@ -542,6 +556,7 @@ class _LayerNormLinear(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
...
@@ -69,7 +69,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
-from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ..cpu_offload import (
+is_cpu_offload_enabled,
+start_offload,
+mark_not_offload,
+mark_activation_offload,
+)
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
@@ -235,6 +240,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
+if is_cpu_offload_enabled():
+start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
@@ -577,6 +584,18 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(act_out)
act_out = None
+if cpu_offloading:
+mark_not_offload(
+ln_weight,
+ln_bias,
+fc1_weight_final,
+fc1_weight,
+fc1_bias,
+fc2_weight_final,
+fc2_weight,
+fc2_bias,
+)
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
...
@@ -68,7 +68,12 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_custom
from ..export import is_in_onnx_export_mode, assert_warmed_up
-from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ..cpu_offload import (
+is_cpu_offload_enabled,
+start_offload,
+mark_not_offload,
+mark_activation_offload,
+)
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["Linear"]
@@ -229,6 +234,9 @@ class _Linear(torch.autograd.Function):
else:
inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
inputmat_total = inputmat
+if is_cpu_offload_enabled():
+start_offload(inputmat)
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
@@ -417,6 +425,7 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx.weight_object = weight
+mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
...
@@ -372,9 +372,9 @@ class FusedAdam(torch.optim.Optimizer):
"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
-data = torch.zeros_like(param, dtype=torch.int16)
+data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
else:
-data = torch.empty_like(param, dtype=dtype)
+data = torch.empty(param.shape, dtype=dtype, device=param.device)
if zero_buffer:
data.zero_()
...
@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
+import math
import torch
from torch.utils._pytree import tree_map
@@ -20,6 +21,11 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
_stride_from_shape,
)
+_quantized_tensor_cpu_supported_ops = (
+torch.ops.aten.empty_like.default,
+torch.ops.aten.copy_.default,
+)
class QuantizedTensorStorage:
r"""Base class for all *TensorStorage classes.
@@ -35,7 +41,7 @@ class QuantizedTensorStorage:
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
-to behave like regular torch.Tensor (liek __torch_dispatch__)."""
+to behave like regular torch.Tensor (like __torch_dispatch__)."""
_quantizer: Optional[Quantizer]
@@ -63,6 +69,12 @@ class QuantizedTensorStorage:
f"{self.__class__.__name__} class does not implement update_usage function"
)
+def get_usages(self) -> Dict[str, bool]:
+"""Get the usage of the tensor"""
+raise NotImplementedError(
+f"{self.__class__.__name__} class does not implement get_usages function"
+)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
@@ -128,6 +140,7 @@ def prepare_for_saving(
t, t_obj = tensor.prepare_for_saving()
tensor_list.extend(t)
tensor_objects_list.append(t_obj)
return tensor_list, tensor_objects_list
@@ -314,6 +327,13 @@ class Quantizer(abc.ABC):
"""Returns whether or not given tensor can be quantized"""
return True
+def get_usages(self) -> Dict[str, bool]:
+"""Get the usage of the quantizer"""
+return {
+"rowwise": self.rowwise_usage,
+"columnwise": self.columnwise_usage,
+}
class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
@@ -325,7 +345,14 @@ class QuantizedTensor(torch.Tensor):
"""
-def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False):
+def __new__(
+cls,
+shape: Iterable[int],
+dtype: torch.dtype,
+*,
+requires_grad: bool = False,
+device: Optional[torch.device] = None,
+):
# We are assuming only contiguous tensors
stride = _stride_from_shape(shape)
instance = torch.Tensor._make_wrapper_subclass(
@@ -336,7 +363,7 @@ class QuantizedTensor(torch.Tensor):
dtype=dtype,
layout=torch.strided,
requires_grad=requires_grad,
-device=torch.cuda.current_device(),
+device=torch.cuda.current_device() if device is None else device,
)
return instance
@@ -366,6 +393,9 @@ class QuantizedTensor(torch.Tensor):
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully"""
+raise NotImplementedError(
+f"{self.__class__.__name__} class does not implement clear function"
+)
def __repr__(self, *, tensor_contents=None) -> str:
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
@@ -407,6 +437,26 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.copy_.default:
dst = args[0]
src = args[1]
+if (
+isinstance(dst, QuantizedTensor)
+and isinstance(src, QuantizedTensor)
+and type(dst._quantizer) is type(src._quantizer)
+and set(src.get_usages().keys()) == set(dst.get_usages().keys())
+and all(
+src.get_usages()[usage] == dst.get_usages()[usage]
+for usage in src.get_usages().keys()
+)
+):
+dst_tensors, dst_tensor_obj = dst.prepare_for_saving()
+src_tensors, src_tensor_obj = src.prepare_for_saving()
+for dst_tensor, src_tensor in zip(dst_tensors, src_tensors):
+if dst_tensor is not None:
+dst_tensor.copy_(src_tensor, *args[2:], **kwargs)
+dst_tensor_obj.restore_from_saved(dst_tensors)
+src_tensor_obj.restore_from_saved(src_tensors)
+return None
if isinstance(dst, QuantizedTensor):
dst.quantize_(src)
else:
@@ -419,6 +469,36 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")
+# Empty like op
+if func == torch.ops.aten.empty_like.default:
+tensor = args[0]
+device = kwargs.get("device", tensor.device)
+requires_grad = kwargs.get("requires_grad", tensor.requires_grad)
+pin_memory = kwargs.get("pin_memory", False)
+usage = tensor.get_usages()
+quantizer_usage = tensor._quantizer.get_usages()
+tensor._quantizer.set_usage(**usage)
+out = tensor._quantizer.make_empty(
+shape=tensor.shape,
+dtype=tensor.dtype,
+device=device,
+requires_grad=requires_grad,
+pin_memory=pin_memory,
+)
+tensor._quantizer.set_usage(**quantizer_usage)
+return out
+if func == torch.ops.aten.numel.default:
+tensor = args[0]
+return math.prod(tensor.size())
+if func == torch.ops.aten.is_pinned.default:
+tensor = args[0]
+for t in tensor.get_data_tensors():
+if t is not None:
+return func(t)
+return False  # Or error out?
def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
@@ -463,6 +543,16 @@ class QuantizedTensor(torch.Tensor):
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
+def check_if_cpu(arg):
+if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
+assert (
+func in _quantized_tensor_cpu_supported_ops
+), f"QuantizedTensor on CPU does not support this operation: {func}"
+return arg
+args = tree_map(check_if_cpu, args)
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
...
@@ -214,6 +214,7 @@ class Float8BlockQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
+pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
@@ -229,12 +230,13 @@ class Float8BlockQuantizer(Quantizer):
data = None
scale_inv = None
if self.rowwise_usage:
-data = torch.empty(shape, dtype=torch.uint8, device=device)
+data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
dtype=torch.float32,
device=device,
+pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
@@ -242,13 +244,17 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
-self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
+self.get_columnwise_shape(shape),
+dtype=torch.uint8,
+device=device,
+pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
dtype=torch.float32,
device=device,
+pin_memory=pin_memory,
)
# Construct FP8 tensor
...
@@ -101,6 +101,7 @@ class Float8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
+pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
@@ -108,16 +109,19 @@ class Float8Quantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
-data = torch.empty(shape, dtype=torch.uint8, device=device)
+data = None
+if self.rowwise_usage:
+data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
-transpose_shape = [data.size(-1)] + list(data.shape[:-1])
+transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
+pin_memory=pin_memory,
)
# Construct FP8 tensor
@@ -125,7 +129,7 @@ class Float8Quantizer(Quantizer):
shape=shape,
dtype=dtype,
data=data,
-fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
+fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
@@ -287,6 +291,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
+pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
@@ -294,23 +299,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
-data = torch.empty(shape, dtype=torch.uint8, device=device)
+data = None
+if self.rowwise_usage:
+data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
-transpose_shape = [data.size(-1)] + list(data.shape[:-1])
+transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
+pin_memory=pin_memory,
)
# Construct FP8 tensor
return Float8Tensor(
shape=shape,
dtype=dtype,
data=data,
-fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
+fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
@@ -715,13 +723,21 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
return cls.clone(args[0])
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
# Just copy FP8 attrs if copying between Float8Tensors
if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor):
-dst._data.copy_(src._data.detach())
-dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size()))
-if src._transpose is not None or dst._transpose is not None:
+if dst._data is not None:
+dst._data.copy_(src._data.detach(), *args[2:], **kwargs)
+if dst._scale_inv is not None:
+dst._scale_inv.copy_(
+src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs
+)
+if dst._transpose is not None and not dst._transpose_invalid:
+if not src._transpose_invalid:
+dst._transpose.copy_(src._transpose, *args[2:], **kwargs)
+else:
dst._create_transpose()
return dst
elif func in _ops_to_preserve_subclass_in_fsdp2:
...
@@ -90,6 +90,7 @@ class MXFP8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
+pin_memory: bool = False,
) -> MXFP8Tensor:
# Canonicalize tensor attributes
@@ -105,24 +106,29 @@ class MXFP8Quantizer(Quantizer):
)
# Allocate FP8 data
-data = torch.empty(shape, dtype=torch.uint8, device=device)
+data = None
+scale_inv = None
+if self.rowwise_usage:
+data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
+pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
-columnwise_data = torch.empty_like(data)
+columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
device=device,
+pin_memory=pin_memory,
)
# Construct FP8 tensor
@@ -348,11 +354,17 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
)
if rowwise_matches and columnwise_matches:
if dst._rowwise_data is not None:
-dst._rowwise_data.copy_(src._rowwise_data.detach())
-dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
+dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
+dst._rowwise_scale_inv.copy_(
+src._rowwise_scale_inv.detach(), *args[2:], **kwargs
+)
if dst._columnwise_data is not None:
-dst._columnwise_data.copy_(src._columnwise_data.detach())
-dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
+dst._columnwise_data.copy_(
+src._columnwise_data.detach(), *args[2:], **kwargs
+)
+dst._columnwise_scale_inv.copy_(
+src._columnwise_scale_inv.detach(), *args[2:], **kwargs
+)
return dst
# FSDP2 related functions.
...
@@ -6,7 +6,7 @@
from __future__ import annotations
from collections.abc import Iterable
import math
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
import functools
import torch
@@ -265,6 +265,7 @@ class NVFP4Quantizer(Quantizer):
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
+pin_memory: bool = False,
requires_grad: bool = False,
) -> NVFP4Tensor:
@@ -288,11 +289,18 @@ class NVFP4Quantizer(Quantizer):
scale_inv = None
amax_rowwise = None
if self.rowwise_usage:
-data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
+data = torch.empty(
+self.convert_shape_for_fp4(shape),
+dtype=torch.uint8,
+device=device,
+pin_memory=pin_memory,
+)
scale_shape = self.get_scale_shape(shape, columnwise=False)
-scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
+scale_inv = torch.empty(
+scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+)
# Allocate per tensor scale inverse. FP32 format.
-amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
+amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
columnwise_data = None
@@ -306,12 +314,15 @@ class NVFP4Quantizer(Quantizer):
self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
dtype=torch.uint8,
device=device,
+pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
-columnwise_scale_shape, dtype=torch.uint8, device=device
+columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
-amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
+amax_columnwise = torch.zeros(
+1, dtype=torch.float32, device=device, pin_memory=pin_memory
+)
# Construct FP8 tensor
return NVFP4Tensor(
@@ -498,6 +509,12 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
return self
raise ValueError("NVFP4Tensor does not support different memory formats!")
+def get_usages(self) -> Dict[str, bool]:
+return {
+"rowwise": self._rowwise_data is not None,
+"columnwise": self._columnwise_data is not None,
+}
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -520,16 +537,20 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
)
if tensor._rowwise_data is not None:
-rowwise_data = data_init_func(tensor._rowwise_data)
-rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
-amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
+rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs)
+rowwise_scale_inv = scale_inv_init_func(
+tensor._rowwise_scale_inv, *args[1:], **kwargs
+)
+amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs)
else:
rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
if tensor._columnwise_data is not None:
-columnwise_data = data_init_func(tensor._columnwise_data)
-columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
-amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
+columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs)
+columnwise_scale_inv = scale_inv_init_func(
+tensor._columnwise_scale_inv, *args[1:], **kwargs
+)
+amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs)
else:
columnwise_data, columnwise_scale_inv, amax_columnwise = (
None,
...
@@ -420,3 +420,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return
return
+def get_usages(self) -> Dict[str, bool]:
+"""Get the usage of the tensor"""
+return {
+"rowwise": self._rowwise_data is not None,
+"columnwise": self._columnwise_data is not None,
+}
@@ -225,3 +225,12 @@ class Float8TensorStorage(QuantizedTensorStorage):
if not needs_data_transpose:
self._transpose = None
self._transpose_invalid = True
+def get_usages(self) -> Dict[str, bool]:
+"""Get the usage of the tensor"""
+usages = {"rowwise": self._data is not None}
+if is_non_tn_fp8_gemm_supported():
+usages["columnwise"] = self._data is not None
+else:
+usages["columnwise"] = self._transpose is not None and not self._transpose_invalid
+return usages