Commit 162e32d4 authored by dongcl

support activation offloading

parent 8aca187f
@@ -298,8 +298,8 @@ class _GroupedLinear(torch.autograd.Function):
        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
            for i in range(ctx.num_gemms):
                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                w.main_grad = main_grads[i]
                weights[i] = w
                weights[i].main_grad = main_grads[i]
                if ctx.offload_activation and weights[0].requires_grad:
                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
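For context on the main_grad / grad_added_to_main_grad handling in this hunk, here is a minimal sketch of the fused weight-gradient accumulation convention it relies on. The names follow Megatron-LM style usage; the helper below is purely illustrative and is not code from this commit: the wgrad GEMM accumulates straight into the fp32 main_grad buffer and marks the parameter so later code knows the gradient was already applied there.

# Illustrative sketch only -- not code from this commit.
# Assumed convention (Megatron-LM style): each Parameter carries a persistent fp32
# buffer `main_grad`; when the wgrad is fused into it, `grad_added_to_main_grad`
# is set so the framework does not accumulate the same gradient a second time.
import torch

def fused_wgrad_accumulate(weight: torch.nn.Parameter,
                           grad_output: torch.Tensor,
                           inp: torch.Tensor) -> None:
    # Accumulate d(loss)/d(weight) directly into the fp32 main_grad buffer.
    weight.main_grad += grad_output.T @ inp
    # Tell downstream code the gradient already lives in main_grad.
    weight.grad_added_to_main_grad = True

# Usage sketch:
w = torch.nn.Parameter(torch.randn(8, 4))
w.main_grad = torch.zeros(8, 4, dtype=torch.float32)
fused_wgrad_accumulate(w, grad_output=torch.randn(16, 8), inp=torch.randn(16, 4))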
@@ -38,6 +38,7 @@ from ..utils import (
    nvtx_range_push,
    requires_grad,
    needs_quantized_gemm,
    get_activation_offloading,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
@@ -431,10 +432,33 @@ class _LayerNormLinear(torch.autograd.Function):
        )
        nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

        offload_activation = False
        if get_activation_offloading():
            offload_activation = True
            if not inputmat.is_contiguous():
                inputmat = inputmat.contiguous()
            inputmat.offloading_activation = True
        ctx.offload_activation = offload_activation

        if offload_activation and cpu_offloading:
            raise ValueError(
                "Do not use offload_activation and cpu_offloading at the same time."
            )

        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
            if hasattr(weight, "grad_added_to_main_grad"):
                ctx.has_grad_added_to_main_grad = True
                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
                weight.grad_added_to_main_grad = True
                ctx.weight_object = weight
            else:
                ctx.has_grad_added_to_main_grad = False

        if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
            if ctx.grad_added_to_main_grad:
            if ctx.has_grad_added_to_main_grad:
                # If you are passing torch.nn.Parameter through the Torch hooks, you will
                # get back torch.Tensor. Torch rips off the Parameter wrapper.
                # You need to preserve the weight object to have all the attributes user
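The comment above is why these hunks stash the original weight on ctx and restore grad_added_to_main_grad in backward. A short sketch (illustrative only, not part of this commit) of how round-tripping a Parameter through an offloading pack/unpack pair, as saved-tensor hooks do, yields a plain torch.Tensor that has lost both the Parameter wrapper and any custom attributes such as main_grad:

# Illustrative sketch only -- not code from this commit.
import torch

w = torch.nn.Parameter(torch.randn(4, 4))
w.main_grad = torch.zeros(4, 4)  # custom attribute attached outside of PyTorch

def pack(t: torch.Tensor) -> torch.Tensor:
    # emulate offloading at save time: detach and copy to host memory
    return t.detach().to("cpu", copy=True)

def unpack(t: torch.Tensor) -> torch.Tensor:
    # emulate restoring the tensor when backward needs it
    return t.clone()

restored = unpack(pack(w))
print(isinstance(restored, torch.nn.Parameter))  # False: the Parameter wrapper is gone
print(hasattr(restored, "main_grad"))            # False: custom attributes are gone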
@@ -567,9 +591,11 @@ class _LayerNormLinear(torch.autograd.Function):
        # For CPU offloading, weight and weight.main_grad were offloaded to different
        # tensors, so we need to reconnect them here.
        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
            if ctx.grad_added_to_main_grad:
        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
            if ctx.has_grad_added_to_main_grad:
                origin_weight = ctx.weight_object
                if ctx.offload_activation:
                    origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                origin_weight.main_grad = main_grad
@@ -949,7 +975,7 @@ class _LayerNormLinear(torch.autograd.Function):
            )
            dgrad = dgrad.reshape(inputmat.size())
        elif ctx.normalization == "RMSNorm":
            if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
            if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16):
                dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
            else:
@@ -1546,6 +1572,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        else:
            fwd_fn = _LayerNormLinear.forward
            args = [None]
        args += (
            inp,
            self.layer_norm_weight,
@@ -38,6 +38,7 @@ from ..utils import (
    assert_dim_for_fp8_exec,
    nvtx_range_pop,
    nvtx_range_push,
    get_activation_offloading,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
@@ -396,10 +397,30 @@ class _Linear(torch.autograd.Function):
        )
        nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

        offload_activation = False
        if get_activation_offloading():
            offload_activation = True
            if not saved_inputmat.is_contiguous():
                saved_inputmat = saved_inputmat.contiguous()
            saved_inputmat.offloading_activation = True
        ctx.offload_activation = offload_activation

        if offload_activation and cpu_offloading:
            raise ValueError(
                "Do not use offload_activation and cpu_offloading at the same time."
            )

        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
            if ctx.has_grad_added_to_main_grad:
                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
                ctx.weight_object = weight
                weight.grad_added_to_main_grad = True

        if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
            if ctx.grad_added_to_main_grad:
            if ctx.has_grad_added_to_main_grad:
                # If you are passing torch.nn.Parameter through the Torch hooks, you will
                # get back torch.Tensor. Torch rips off the Parameter wrapper.
                # You need to preserve the weight object to have all the attributes user
@@ -482,7 +503,6 @@ class _Linear(torch.autograd.Function):
        inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
            restore_from_saved(ctx.tensor_objects, saved_tensors)
        )
        # Delete the references to tensor objects once they've been consumed
        # by the `restore_from_saved` method to construct back the actual tensors.
        ctx.tensor_objects = None
@@ -494,9 +514,11 @@ class _Linear(torch.autograd.Function):
            else None
        )
        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
            if ctx.grad_added_to_main_grad:
        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
            if ctx.has_grad_added_to_main_grad:
                weight = ctx.weight_object
                if ctx.offload_activation:
                    weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                weight.main_grad = main_grad
@@ -14,6 +14,34 @@ import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from torch.utils.cpp_extension import IS_HIP_EXTENSION

ActivationOffloadEnabled = False


def get_activation_offloading():
    """Return whether activation offloading is currently enabled."""
    global ActivationOffloadEnabled
    return ActivationOffloadEnabled


def set_activation_offloading(activation_offloading):
    """Enable or disable activation offloading globally."""
    global ActivationOffloadEnabled
    ActivationOffloadEnabled = activation_offloading


class ActivationOffloadContextManager:
    """A reusable context manager for switching ActivationOffloadEnabled."""

    def __init__(self, activation_offloading):
        self.activation_offloading = activation_offloading

    def __enter__(self):
        self.origin_activation_offloading = get_activation_offloading()
        set_activation_offloading(self.activation_offloading)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        set_activation_offloading(self.origin_activation_offloading)


def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
    """Check if any of the given tensors require gradient."""
    for tensor in tensors:
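A usage sketch of the helpers added above. The import path is assumed from the hunk location (the modified utils module under transformer_engine/pytorch) and may differ in the actual repository layout; everything else uses only the API introduced by this commit. The context manager flips the module-level flag for the duration of a block and restores the previous value on exit, so offload-aware forward passes can be scoped.

# Usage sketch; import path assumed, not verified against the repository layout.
from transformer_engine.pytorch.utils import (
    ActivationOffloadContextManager,
    get_activation_offloading,
    set_activation_offloading,
)

set_activation_offloading(False)            # explicit baseline
with ActivationOffloadContextManager(True):
    assert get_activation_offloading()      # enabled only inside this block
    # run forward passes of Linear / LayerNormLinear here; their saved inputs are
    # tagged with `offloading_activation = True`, presumably for an offload hook
assert not get_activation_offloading()      # previous value restored on exit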