Unverified Commit 59130cc9 authored by Tim Moon, committed by GitHub

[PyTorch] Support activation CPU offloading in fusible ops (#2158)
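
For context, a minimal sketch of how the CPU-offload context exercised by this PR is used with fusible ops, mirroring the updated test; the layer count, sizes, and variable names are illustrative rather than copied from the test:

```python
import torch
import transformer_engine.pytorch as te

# Illustrative sizes; the test uses its own SIZE / NUM_LAYERS constants.
hidden, num_layers = 512, 5
layers = [te.ops.Linear(hidden, hidden, dtype=torch.bfloat16) for _ in range(num_layers)]

# Offload activations (not weights) of every layer to host memory between
# the forward and backward passes.
offload_context, sync_function = te.get_cpu_offload_context(
    enabled=True,
    num_layers=num_layers,
    model_layers=num_layers + 1,
    offload_activations=True,
    offload_weights=False,
)

x = torch.randn(32, hidden, hidden, dtype=torch.bfloat16, device="cuda", requires_grad=True)
tensor = x
for layer in layers:
    with offload_context:
        tensor = layer(tensor)
    tensor = sync_function(tensor)
# Dummy step so the final layer's activations are offloaded too (as in the test).
with offload_context:
    tensor = tensor.clone()
tensor = sync_function(tensor)
tensor.sum().backward()
```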



* Add CPU offloading logic to ops. Fix test to compute dgrad.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
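
A sketch of the pattern this commit applies inside each op's forward pass; the `save_for_backward_with_offload` helper name is hypothetical (the ops call `is_cpu_offload_enabled` and `mark_activation_offload` inline, as the diff below shows):

```python
from transformer_engine.pytorch.cpu_offload import (
    is_cpu_offload_enabled,
    mark_activation_offload,
)

def save_for_backward_with_offload(ctx, *tensors):
    """Tag activations for CPU offload before stashing them for backward."""
    if is_cpu_offload_enabled():
        # The offload handler moves tagged tensors to host memory after the
        # forward pass and brings them back before the backward pass.
        mark_activation_offload(*tensors)
    ctx.save_for_backward(*tensors)
```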

* Make sure grads are contiguous in op backwards
Signed-off-by: Tim Moon <tmoon@nvidia.com>
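
The contiguity guarantee lives in the shared `maybe_dequantize` helper (see the diff below). As a rough sketch of the non-quantized path only, with a hypothetical helper name:

```python
from typing import Optional

import torch

def _to_dtype_contiguous(tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Hypothetical stand-in for the non-quantized path of maybe_dequantize:
    grad tensors handed to backward kernels are cast and made contiguous."""
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    return tensor
```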

* Add op-based MLP to CPU offloading tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>
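
For reference, the op-based MLP that the test now constructs, shown as a standalone snippet; the `SIZE` value here is illustrative (the test defines its own constant):

```python
import torch
import transformer_engine.pytorch as te

SIZE = 512  # illustrative; the test uses its own SIZE

layernorm_mlp_ops = te.ops.Sequential(
    te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
    te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    te.ops.GELU(),
    te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
)

x = torch.randn(8, SIZE, SIZE, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = layernorm_mlp_ops(x)
y.sum().backward()
```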

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Handle different weight cache behavior on Hopper/Blackwell

Add MXFP8 to CPU offload tests.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
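
The expected size of the weight cache that stays on the GPU differs by architecture: when only TN FP8 GEMM layouts are supported (Hopper-style), module-based models keep an FP8 transpose of each weight, doubling the cached bytes, while the op-based models deallocate it; MXFP8 additionally stores one scale byte per 32 elements. A back-of-the-envelope sketch of the arithmetic behind the test's `_estimate_cached_weight_size`, with simplified, illustrative names:

```python
def estimate_cached_weight_mib(
    weight_elements: int,
    *,
    keeps_transpose: bool,  # True when only TN FP8 GEMMs are supported (Hopper-style)
    mxfp8: bool = False,
) -> float:
    """Rough estimate of the FP8 weight cache size in MiB."""
    bytes_per_element = 1 + (1 / 32 if mxfp8 else 0)  # data byte (+ scale byte per 32)
    copies = 2 if keeps_transpose else 1              # row-wise data (+ cached transpose)
    return copies * weight_elements * bytes_per_element / 1024**2

# Example: one 512x512 weight quantized with per-tensor FP8 scaling.
print(estimate_cached_weight_mib(512 * 512, keeps_transpose=True))   # 0.5 MiB
print(estimate_cached_weight_mib(512 * 512, keeps_transpose=False))  # 0.25 MiB
```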

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove MXFP8 test
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cd2034f3
@@ -2,8 +2,11 @@
#
# See LICENSE for license information.
import contextlib
import gc
import os
from contextlib import nullcontext
from typing import Iterable, Optional
import pytest
import torch
@@ -11,15 +14,16 @@ import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends

# Check if FP8 is supported
# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_recipes = [None]
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
@@ -48,85 +52,139 @@ model_types = {
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}
def _get_input():
    return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()


def _make_input() -> torch.Tensor:
    """Generate random input tensor."""
    return torch.randn(
        (128, SIZE, SIZE),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )
def _get_fp8_weight_cache_size(models, fp8_recipe):
    """
    Calculate the total FP8 weight cache size (in MB) for a list of models.
    """
    if fp8_recipe is None:


def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Perform forward and backward pass"""
    tensor = _make_input()
    for module in modules:
        with te.fp8_autocast(
            enabled=quantization_recipe is not None,
            fp8_recipe=quantization_recipe,
        ):
            tensor = module(tensor)
    tensor.sum().backward()


def _estimate_cached_weight_size(
    model_name: str,
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> float:
    """Calculate the memory (in MiB) needed for weight caching."""

    # The weight params are cached directly for unquantized compute
    if quantization_recipe is None:
        return 0
    params_bytes = 0
    for model in models:
        for name, param in model.named_parameters():
            if "weight" in name:
                params_bytes += param.numel()

    # Count number of weight param elements
    param_elements = 0
    for module in modules:
        for param in module.parameters():
            if param.dim() == 2:
                param_elements += param.numel()

    # FP8 tensor-scaling caches one byte per element
    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
        if not is_non_tn_fp8_gemm_supported() and model_name not in (
            "linear_op",
            "layernorm_mlp_ops",
        ):
            # Modules do not deallocate FP8 transpose for weights
            return 2 * param_elements / 1024**2
        return param_elements / 1024**2

    # MXFP8 caches one data byte per element and one scale byte per 32
    # elements
    if quantization_recipe.mxfp8():
        if model_name not in ("linear_op", "layernorm_mlp_ops"):
            # Modules do not deallocate column-wise MXFP8 data for weights
            return 2 * param_elements * (1 + 1 / 32) / 1024**2
        return param_elements * (1 + 1 / 32) / 1024**2

    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.

    Memory measurement excludes the input and output tensors.
    # One byte for columnwise and one byte for rowwise,
    # hence multiply by 2 and convert to MB
    # there is 1 byte of scale per 32 elements in mxFP8
    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
    return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
    """

    # Reset memory
    gc.collect()
    torch.cuda.empty_cache()


def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
    tensor = _get_input()

    # Context and sync function for CPU offloading
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(models) - 1,
            model_layers=len(models),
            num_layers=len(modules),
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = nullcontext()
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x

    for model in models:
    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.fp8_autocast(
            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
            enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
        ), offload_context:
            tensor = model(tensor)
            tensor = module(tensor)
        tensor = sync_function(tensor)
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
    torch.cuda.synchronize()

    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()

    return max_mem_used
    # Memory usage in MiB
    return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model_key", model_types.keys())
def test_cpu_offload(fp8_recipe, model_key) -> None:
    """
    We run three configurations:
    (1) No offloading: All activations remain on the GPU between forward and backward passes.
    (2) No offloading (one layer): Only the first layer's activations remain on the GPU between
    forward and backward passes.
    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
    between forward and backward passes, while all other layers are offloaded to the CPU.
    We expect the memory consumption of configurations (2) and (3) to be similar, with
    the difference being the size of the FP8 cache that is not offloaded to the CPU.
    We also expect this memory consumption to be smaller than in scenario (1).
    """
    import gc


@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""
    gc.collect()

    model_cls = model_types[model_key]
    models_list = [model_cls() for _ in range(NUM_LAYERS)]
    if model_key in ["multihead_attention", "transformer_layer"]:
    # Construct model
    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
    if model_name in ["multihead_attention", "transformer_layer"]:
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
@@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    without_offloading = _measure_memory_between_forward_and_backward(
        models_list, fp8_recipe, False
    )
    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
        models_list[:1], fp8_recipe, False
    )
    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)
    # Warmup
    _warmup_model(modules_list, quantization_recipe)

    assert with_offloading < without_offloading
    # Measure cached memory after forward pass
    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

    # The only difference between the memory consumption of with_offloading
    # and without_offloading_one_layer should be the size of the FP8 weights cache,
    # which is not offloaded to the CPU.
    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
    assert (
        memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
    # Check for expected memory usage
    assert memory_with_offload < memory_without_offload
    memory_from_cached_weights = _estimate_cached_weight_size(
        model_name,
        modules_list,
        quantization_recipe,
    )
    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
@@ -29,7 +29,9 @@ def maybe_dequantize(
    if is_quantized_tensor(tensor):
        return tensor.dequantize(dtype=dtype)
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
        tensor = tensor.to(dtype)
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    return tensor
@@ -11,6 +11,7 @@ from typing import Optional
import torch
import transformer_engine_torch as tex

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext

@@ -110,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
        # Save state for backward pass
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x)
            ctx.save_for_backward(x)
            ctx.dtype = dtype
            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -13,6 +13,7 @@ from typing import Any, Optional
import torch

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
    CudaRNGStatesTracker,
    gather_along_first_dim,

@@ -964,6 +965,8 @@ class BasicLinear(BasicOperation):
        # Save state for backward pass
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x_local)
            ctx.save_for_backward(x_local, w)
            ctx.with_quantized_compute = with_quantized_compute
            ctx.input_quantizer = input_quantizer
@@ -9,6 +9,7 @@ from typing import Optional
import torch
import transformer_engine_torch as tex

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from .._common import maybe_autocast_dtype, maybe_dequantize

@@ -70,6 +71,8 @@ class Dropout(BasicOperation):
        # Save context for backward
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(mask)
            ctx.save_for_backward(mask)
            ctx.impl = impl
            ctx.dropout_probability = self.dropout_probability
@@ -10,10 +10,8 @@ import os
import torch

from ...utils import clear_tensor_data
from ... import torch_version
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...jit import (
    l2normalization_fused,
    l2normalization_fwd_fused,

@@ -22,6 +20,9 @@ from ...jit import (
    warmup_jit_l2normalization_all_dtypes,
)
from ...tensor import Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize


class L2Normalization(BasicOperation):

@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation):
        # Save state for backward pass
        if requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x, rsqrt_norm)
            ctx.save_for_backward(x, rsqrt_norm)

        return y
@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd

from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
    canonicalize_device,
    canonicalize_dtype,

@@ -22,8 +25,6 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer


class LayerNorm(BasicOperation):

@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation):
        # Save state for backward pass
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x, means, rstdevs)
            ctx.save_for_backward(x, means, rstdevs)
            ctx.dtype = dtype
@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd

from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
    canonicalize_device,
    canonicalize_dtype,

@@ -22,8 +25,6 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer


class RMSNorm(BasicOperation):

@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation):
        # Save state for backward pass
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x, rstdevs)
            ctx.save_for_backward(x, rstdevs)
            ctx.dtype = dtype
@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext


class ForwardLinearBiasActivation(FusedOperation):

@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation):
        # Save state for backward pass
        if linear_op_ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x_local)
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext


class ForwardLinearBiasAdd(FusedOperation):

@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation):
        # Save state for backward pass
        if linear_op_ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x_local)
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
@@ -10,14 +10,15 @@ from typing import Any, Optional
import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...tensor import Quantizer


class ForwardLinearScaleAdd(FusedOperation):

@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation):
        # Save state for backward pass
        if linear_op_ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x_local)
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
@@ -12,6 +12,7 @@ import torch
from transformer_engine_torch import CommOverlapType

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager
from ...module.base import (

@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation):
        # Save state for backward pass
        if linear_op_ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x_local)
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer