Commit d262ef4c authored by yuguo

Merge branch 'main' into 'main'

support activation offloading

See merge request dcutoolkit/deeplearing/TransformerEngine!44
parents b15412aa 6cfcde78
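At a high level, this change tags activation tensors saved for backward with an `offloading_activation` attribute (gated by the new `get_activation_offloading()` flag), rejects combining activation offloading with `cpu_offloading`, and saves and restores each weight's `grad_added_to_main_grad` flag around the offloaded backward. The engine that actually moves tagged tensors off the device is not part of this diff; below is a minimal sketch of one possible consumer built on PyTorch's `saved_tensors_hooks`. The pack/unpack design and all names in it are assumptions, not code from this MR.

# Minimal sketch (not part of this MR) of how an offload engine could consume
# the `offloading_activation` tag that this diff attaches to saved activations.
import torch

def pack(t):
    # Move tagged activations to host memory; leave everything else on device.
    if getattr(t, "offloading_activation", False):
        return ("offloaded", t.to("cpu", non_blocking=True), t.device)
    return ("resident", t, None)

def unpack(packed):
    state, t, device = packed
    return t.to(device, non_blocking=True) if state == "offloaded" else t

layer = torch.nn.Linear(16, 16)              # stand-in for a TE module
inp = torch.randn(4, 16, requires_grad=True)
inp.offloading_activation = True             # tag consumed by the forward passes below

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = layer(inp)
out.sum().backward()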
@@ -157,6 +157,29 @@ class _BatchLinear(torch.autograd.Function):
             for t in saved_inputmats:
                 if t is not None:
                     t.activation_offloading = True
+
+        offload_activation = False
+        if hasattr(inp, "offloading_activation"):
+            offload_activation = True
+            for i in range(num_gemms):
+                saved_inputmats[i].offloading_activation = inp.offloading_activation
+        ctx.offload_activation = offload_activation
+
+        if offload_activation and cpu_offloading:
+            raise ValueError(
+                "Activation offloading and CPU offloading cannot be enabled at the same time."
+            )
+
+        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+            grad_added_to_main_grad_list = []
+            for weight in weights:
+                if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
+                    grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                    weight.grad_added_to_main_grad = True
+                else:
+                    grad_added_to_main_grad_list.append(None)
+            ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
+
         ctx.save_for_backward(
             None,
             *saved_inputmats,
@@ -164,7 +187,7 @@ class _BatchLinear(torch.autograd.Function):
             *weights,
             *weights_fp8,
             *[
-                w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
+                w.main_grad if (cpu_offloading or offload_activation) and fuse_wgrad_accumulation else None
                 for w in weights
             ],
         )
@@ -203,11 +226,14 @@ class _BatchLinear(torch.autograd.Function):
         weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
         weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
         main_grads = saved_tensors[4 * ctx.num_gemms :]
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], False)
+                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
+                if ctx.offload_activation and weights[i].requires_grad:
+                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
         global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
         grad_output = grad_output.contiguous()
         grad_output_mats = torch.split(
...
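Both `_BatchLinear` above and `_GroupedLinear` in the next hunk use the same `grad_added_to_main_grad` round trip: forward stashes each weight's current flag (or None when the weight has none) and forces it to True, and backward restores the stashed value once `main_grad` has been re-attached. A condensed, standalone illustration of that pattern; the names mirror the diff, but the three-weight setup is invented:

import torch

weights = [torch.nn.Parameter(torch.ones(2)) for _ in range(3)]
weights[0].grad_added_to_main_grad = False  # only some weights carry the flag

# Forward-side bookkeeping: stash the old value (or None) and force the flag on.
grad_added_to_main_grad_list = []
for w in weights:
    if w.requires_grad and hasattr(w, "grad_added_to_main_grad"):
        grad_added_to_main_grad_list.append(w.grad_added_to_main_grad)
        w.grad_added_to_main_grad = True
    else:
        grad_added_to_main_grad_list.append(None)

# ... activations come back from host memory and backward runs here ...

# Backward-side restore for the weights that had the flag.
for w, saved in zip(weights, grad_added_to_main_grad_list):
    if saved is not None:
        w.grad_added_to_main_grad = saved

assert weights[0].grad_added_to_main_grad is False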
@@ -211,6 +211,28 @@ class _GroupedLinear(torch.autograd.Function):
                 if isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(columnwise_usage=True)
+
+        offload_activation = False
+        if hasattr(inp, "offloading_activation"):
+            offload_activation = True
+            for i in range(num_gemms):
+                inputmats[i].offloading_activation = inp.offloading_activation
+        ctx.offload_activation = offload_activation
+
+        if offload_activation and cpu_offloading:
+            raise ValueError(
+                "Activation offloading and CPU offloading cannot be enabled at the same time."
+            )
+
+        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+            grad_added_to_main_grad_list = []
+            for weight in weights:
+                if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
+                    grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                    weight.grad_added_to_main_grad = True
+                else:
+                    grad_added_to_main_grad_list.append(None)
+            ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             *inputmats,
             *weights_fp8,
@@ -273,11 +295,13 @@ class _GroupedLinear(torch.autograd.Function):
             biases = saved_tensors[3 * N : 4 * N]
             main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+            if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
                 for i in range(ctx.num_gemms):
                     w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                     w.main_grad = main_grads[i]
                     weights[i] = w
+                    if ctx.offload_activation and weights[i].requires_grad:
+                        weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
             # Preprocess grad output
             grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
...
@@ -38,6 +38,7 @@ from ..utils import (
     nvtx_range_push,
     requires_grad,
     needs_quantized_gemm,
+    get_activation_offloading,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -431,10 +432,33 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
+
+        offload_activation = False
+        if get_activation_offloading():
+            offload_activation = True
+            if not inputmat.is_contiguous():
+                inputmat = inputmat.contiguous()
+            inputmat.offloading_activation = True
+        ctx.offload_activation = offload_activation
+
+        if offload_activation and cpu_offloading:
+            raise ValueError(
+                "Activation offloading and CPU offloading cannot be enabled at the same time."
+            )
+
+        ctx.has_grad_added_to_main_grad = False
+        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
+            if hasattr(weight, "grad_added_to_main_grad"):
+                ctx.has_grad_added_to_main_grad = True
+                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                weight.grad_added_to_main_grad = True
+                ctx.weight_object = weight
+
         if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.grad_added_to_main_grad:
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.has_grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user
@@ -567,9 +591,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            if ctx.grad_added_to_main_grad:
+        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+            if ctx.has_grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
+                if ctx.offload_activation:
+                    origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad
...
@@ -38,6 +38,7 @@ from ..utils import (
     assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
+    get_activation_offloading,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -396,10 +397,30 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
+
+        offload_activation = False
+        if get_activation_offloading():
+            offload_activation = True
+            if not saved_inputmat.is_contiguous():
+                saved_inputmat = saved_inputmat.contiguous()
+            saved_inputmat.offloading_activation = True
+        ctx.offload_activation = offload_activation
+
+        if offload_activation and cpu_offloading:
+            raise ValueError(
+                "Activation offloading and CPU offloading cannot be enabled at the same time."
+            )
+
+        ctx.has_grad_added_to_main_grad = False
+        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.has_grad_added_to_main_grad:
+                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                ctx.weight_object = weight
+                weight.grad_added_to_main_grad = True
+
         if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.grad_added_to_main_grad:
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.has_grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user
@@ -494,9 +515,11 @@ class _Linear(torch.autograd.Function):
                 else None
             )
-            if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-                if ctx.grad_added_to_main_grad:
+            if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+                if ctx.has_grad_added_to_main_grad:
                     weight = ctx.weight_object
+                    if ctx.offload_activation:
+                        weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                     weight.main_grad = main_grad
...
@@ -764,3 +764,30 @@ def make_weak_ref(x):
     if x is None:
         return None
     raise TypeError(f"Invalid type {type(x)} to make weak ref")
+
+
+_ACTIVATION_OFFLOAD_ENABLED = False
+
+
+def get_activation_offloading():
+    """Return whether activation offloading is currently enabled."""
+    return _ACTIVATION_OFFLOAD_ENABLED
+
+
+def set_activation_offloading(activation_offloading):
+    """Globally enable or disable activation offloading."""
+    global _ACTIVATION_OFFLOAD_ENABLED
+    _ACTIVATION_OFFLOAD_ENABLED = activation_offloading
+
+
+class ActivationOffloadContextManager:
+    """A reusable context manager for switching activation offloading on and off."""
+
+    def __init__(self, activation_offloading):
+        self.activation_offloading = activation_offloading
+
+    def __enter__(self):
+        self.origin_activation_offloading = get_activation_offloading()
+        set_activation_offloading(self.activation_offloading)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        set_activation_offloading(self.origin_activation_offloading)
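With these helpers in place, enabling offloading for a region of the model is a matter of flipping the global flag, either scoped through the context manager or globally through the setter. A hypothetical usage sketch; the import path is an assumption based on the `..utils` imports in the diff, and `block`/`hidden` are placeholders:

# Hypothetical usage of the new helpers (import path assumed).
from transformer_engine.pytorch.utils import (
    ActivationOffloadContextManager,
    set_activation_offloading,
)

# Scoped: offload activations only for the layers run inside the block.
with ActivationOffloadContextManager(True):
    hidden = block(hidden)  # `block` is a placeholder TE module

# Global: keep it on for the rest of training.
set_activation_offloading(True)

Note that per the checks added in each forward pass, this must not be combined with `cpu_offloading`; doing both raises a ValueError.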