Unverified Commit 59130cc9 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Support activation CPU offloading in fusible ops (#2158)



* Add CPU offloading logic to ops. Fix test to compute dgrad.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Make sure grads are contiguous in op backwards
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add op-based MLP to CPU offloading tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Handle different weight cache behavior on Hopper/Blackwell

Add MXFP8 to CPU offload tests.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove MXFP8 test
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cd2034f3
...@@ -2,8 +2,11 @@ ...@@ -2,8 +2,11 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import contextlib
import gc
import os import os
from contextlib import nullcontext from typing import Iterable, Optional
import pytest import pytest
import torch import torch
...@@ -11,15 +14,16 @@ import transformer_engine.pytorch as te ...@@ -11,15 +14,16 @@ import transformer_engine.pytorch as te
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported # Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available() fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_recipes = [None] quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available: if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling()) quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
fp8_recipes.append(recipe.DelayedScaling())
model_config = { model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
...@@ -48,85 +52,139 @@ model_types = { ...@@ -48,85 +52,139 @@ model_types = {
"transformer_layer": lambda: te.TransformerLayer( "transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
), ),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
} }
def _get_input(): def _make_input() -> torch.Tensor:
return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() """Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _get_fp8_weight_cache_size(models, fp8_recipe): def _warmup_model(
""" modules: Iterable[torch.nn.Module],
Calculate the total FP8 weight cache size (in MB) for a list of models. quantization_recipe: Optional[recipe.Recipe],
""" ) -> None:
if fp8_recipe is None: """Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.fp8_autocast(
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0 return 0
params_bytes = 0 # Count number of weight param elements
for model in models: param_elements = 0
for name, param in model.named_parameters(): for module in modules:
if "weight" in name: for param in module.parameters():
params_bytes += param.numel() if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
def _measure_cached_memory(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
cpu_offload: bool,
) -> float:
"""Measure the growth in allocated GPU memory in MiB after a model forward pass.
Memory measurement excludes the input and output tensors.
# One byte for columnwise and one byte for rowwise, """
# hence multiply by 2 and convert to MB
# there is 1 byte of scale per 32 elements in mxFP8
factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
# Reset memory
gc.collect()
torch.cuda.empty_cache()
def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): # Context and sync function for CPU offloading
tensor = _get_input()
if cpu_offload: if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context( offload_context, sync_function = te.get_cpu_offload_context(
enabled=True, enabled=True,
num_layers=len(models) - 1, num_layers=len(modules),
model_layers=len(models), model_layers=len(modules) + 1,
offload_activations=True, offload_activations=True,
offload_weights=False, offload_weights=False,
) )
else: else:
offload_context = nullcontext() offload_context = contextlib.nullcontext()
sync_function = lambda x: x sync_function = lambda x: x
for model in models: # Forward pass, with dummy step to trigger offload for last module
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.fp8_autocast( with te.fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
), offload_context: ), offload_context:
tensor = model(tensor) tensor = module(tensor)
tensor = sync_function(tensor) tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
max_mem_used = torch.cuda.memory_allocated() / (1024**2) # Backward pass
torch.cuda.synchronize()
tensor.sum().backward() tensor.sum().backward()
torch.cuda.synchronize()
return max_mem_used # Memory usage in MiB
return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_key", model_types.keys()) @pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(fp8_recipe, model_key) -> None: def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
""" """Check that CPU offloading runs and has expected memory usage."""
We run three configurations:
(1) No offloading: All activations remain on the GPU between forward and backward passes.
(2) No offloading (one layer): Only the first layer's activations remain on the GPU between
forward and backward passes.
(3) With offloading (all layers): Only the last layer's activations remain on the GPU
between forward and backward passes, while all other layers are offloaded to the CPU.
We expect the memory consumption of configurations (2) and (3) to be similar, with
the difference being the size of the FP8 cache that is not offloaded to the CPU.
We also expect this memory consumption to be smaller than in scenario (1).
"""
import gc
gc.collect() # Construct model
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
model_cls = model_types[model_key] if model_name in ["multihead_attention", "transformer_layer"]:
models_list = [model_cls() for _ in range(NUM_LAYERS)]
if model_key in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends( available_backends, *_ = get_available_attention_backends(
model_config["small"], model_config["small"],
qkv_dtype=torch.bfloat16, qkv_dtype=torch.bfloat16,
...@@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: ...@@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
without_offloading = _measure_memory_between_forward_and_backward( # Warmup
models_list, fp8_recipe, False _warmup_model(modules_list, quantization_recipe)
)
without_offloading_one_layer = _measure_memory_between_forward_and_backward(
models_list[:1], fp8_recipe, False
)
with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)
assert with_offloading < without_offloading # Measure cached memory after forward pass
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)
# The only difference between the memory consumption of with_offloading # Check for expected memory usage
# and without_offloading_one_layer should be the size of the FP8 weights cache, assert memory_with_offload < memory_without_offload
# which is not offloaded to the CPU. memory_from_cached_weights = _estimate_cached_weight_size(
memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) model_name,
assert ( modules_list,
memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON quantization_recipe,
) )
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
...@@ -29,7 +29,9 @@ def maybe_dequantize( ...@@ -29,7 +29,9 @@ def maybe_dequantize(
if is_quantized_tensor(tensor): if is_quantized_tensor(tensor):
return tensor.dequantize(dtype=dtype) return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype: if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype) tensor = tensor.to(dtype)
if not tensor.is_contiguous():
tensor = tensor.contiguous()
return tensor return tensor
......
...@@ -11,6 +11,7 @@ from typing import Optional ...@@ -11,6 +11,7 @@ from typing import Optional
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
...@@ -110,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -110,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
ctx.save_for_backward(x) ctx.save_for_backward(x)
ctx.dtype = dtype ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
......
...@@ -13,6 +13,7 @@ from typing import Any, Optional ...@@ -13,6 +13,7 @@ from typing import Any, Optional
import torch import torch
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import ( from ...distributed import (
CudaRNGStatesTracker, CudaRNGStatesTracker,
gather_along_first_dim, gather_along_first_dim,
...@@ -964,6 +965,8 @@ class BasicLinear(BasicOperation): ...@@ -964,6 +965,8 @@ class BasicLinear(BasicOperation):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
ctx.save_for_backward(x_local, w) ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
......
...@@ -9,6 +9,7 @@ from typing import Optional ...@@ -9,6 +9,7 @@ from typing import Optional
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
...@@ -70,6 +71,8 @@ class Dropout(BasicOperation): ...@@ -70,6 +71,8 @@ class Dropout(BasicOperation):
# Save context for backward # Save context for backward
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
ctx.save_for_backward(mask) ctx.save_for_backward(mask)
ctx.impl = impl ctx.impl = impl
ctx.dropout_probability = self.dropout_probability ctx.dropout_probability = self.dropout_probability
......
...@@ -10,10 +10,8 @@ import os ...@@ -10,10 +10,8 @@ import os
import torch import torch
from ...utils import clear_tensor_data
from ... import torch_version from ... import torch_version
from .._common import maybe_dequantize from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..op import BasicOperation, OperationContext
from ...jit import ( from ...jit import (
l2normalization_fused, l2normalization_fused,
l2normalization_fwd_fused, l2normalization_fwd_fused,
...@@ -22,6 +20,9 @@ from ...jit import ( ...@@ -22,6 +20,9 @@ from ...jit import (
warmup_jit_l2normalization_all_dtypes, warmup_jit_l2normalization_all_dtypes,
) )
from ...tensor import Quantizer from ...tensor import Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
class L2Normalization(BasicOperation): class L2Normalization(BasicOperation):
...@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation): ...@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation):
# Save state for backward pass # Save state for backward pass
if requires_grad: if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm)
return y return y
......
...@@ -14,6 +14,9 @@ import torch ...@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -22,8 +25,6 @@ from ...utils import ( ...@@ -22,8 +25,6 @@ from ...utils import (
) )
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class LayerNorm(BasicOperation): class LayerNorm(BasicOperation):
...@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation): ...@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
......
...@@ -14,6 +14,9 @@ import torch ...@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -22,8 +25,6 @@ from ...utils import ( ...@@ -22,8 +25,6 @@ from ...utils import (
) )
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class RMSNorm(BasicOperation): class RMSNorm(BasicOperation):
...@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation): ...@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
ctx.save_for_backward(x, rstdevs) ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
......
...@@ -10,14 +10,11 @@ from typing import Any, Optional ...@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias from ...fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext
class ForwardLinearBiasActivation(FusedOperation): class ForwardLinearBiasActivation(FusedOperation):
...@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
...@@ -10,14 +10,11 @@ from typing import Any, Optional ...@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias from ...fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import ( from ...tensor import Quantizer
FusedOperation, from ..basic import AddExtraInput, BasicLinear, Bias
FusibleOperation, from ..op import FusedOperation, FusibleOperation, OperationContext
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
class ForwardLinearBiasAdd(FusedOperation): class ForwardLinearBiasAdd(FusedOperation):
...@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
...@@ -10,14 +10,15 @@ from typing import Any, Optional ...@@ -10,14 +10,15 @@ from typing import Any, Optional
import torch import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import ( from ..op import (
FusedOperation, FusedOperation,
FusibleOperation, FusibleOperation,
OperationContext, OperationContext,
) )
from ...tensor import Quantizer
class ForwardLinearScaleAdd(FusedOperation): class ForwardLinearScaleAdd(FusedOperation):
...@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation): ...@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from transformer_engine_torch import CommOverlapType from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...module.base import ( from ...module.base import (
...@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment