Commit 8aca187f authored by wenjh

Merge branch 'activation_offloading' into 'release_v2.7'

support activation offloading

See merge request dcutoolkit/deeplearing/TransformerEngine!41
parents 79fa3eba cd4cdf80
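
In both `_BatchLinear` and `_GroupedLinear`, the new code path is switched on by an `offloading_activation` attribute on the input tensor. A minimal sketch of that contract (the helper name below is hypothetical; the attribute name comes from this MR):

import torch

def tag_for_activation_offload(t: torch.Tensor) -> torch.Tensor:
    # The forward passes below only check hasattr(inp, "offloading_activation")
    # and then copy the attribute's value onto the saved input matrices, so the
    # presence of the attribute itself acts as the switch.
    t.offloading_activation = True
    return t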
@@ -157,6 +157,29 @@ class _BatchLinear(torch.autograd.Function):
             for t in saved_inputmats:
                 if t is not None:
                     t.activation_offloading = True
+        offload_activation = False
+        if hasattr(inp, "offloading_activation"):
+            offload_activation = True
+            for i in range(num_gemms):
+                saved_inputmats[i].offloading_activation = inp.offloading_activation
+        ctx.offload_activation = offload_activation
+        if offload_activation and cpu_offloading:
+            raise ValueError(
+                "Do not use offload_activation and cpu_offloading at the same time."
+            )
+        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+            grad_added_to_main_grad_list = []
+            for weight in weights:
+                if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
+                    grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                    weight.grad_added_to_main_grad = True
+                else:
+                    grad_added_to_main_grad_list.append(None)
+            ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
         ctx.save_for_backward(
             None,
             *saved_inputmats,
@@ -164,7 +187,7 @@ class _BatchLinear(torch.autograd.Function):
             *weights,
             *weights_fp8,
             *[
-                w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
+                w.main_grad if (cpu_offloading or offload_activation) and fuse_wgrad_accumulation else None
                 for w in weights
             ],
         )
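
Stripped of diff context, the forward-side bookkeeping added above amounts to the following standalone sketch; the function wrapper is hypothetical, while the attribute and variable names are taken from the diff:

def _stash_offload_state(ctx, inp, saved_inputmats, weights, num_gemms,
                         cpu_offloading, fuse_wgrad_accumulation):
    offload_activation = hasattr(inp, "offloading_activation")
    if offload_activation:
        # Copy the caller's tag onto the tensors actually saved for backward,
        # presumably so the offload engine picks those up instead of inp.
        for i in range(num_gemms):
            saved_inputmats[i].offloading_activation = inp.offloading_activation
    ctx.offload_activation = offload_activation
    if offload_activation and cpu_offloading:
        # The two offloading mechanisms are mutually exclusive.
        raise ValueError(
            "Do not use offload_activation and cpu_offloading at the same time."
        )
    if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
        # Remember each weight's grad_added_to_main_grad flag and force it to
        # True while offloading is in flight; backward restores the old value.
        grad_added_to_main_grad_list = []
        for weight in weights:
            if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
                grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
                weight.grad_added_to_main_grad = True
            else:
                grad_added_to_main_grad_list.append(None)
        ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list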
@@ -203,11 +226,14 @@ class _BatchLinear(torch.autograd.Function):
         weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
         weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
         main_grads = saved_tensors[4 * ctx.num_gemms :]
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], False)
+                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
+                if ctx.offload_activation and weights[i].requires_grad:
+                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
         global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
         grad_output = grad_output.contiguous()
         grad_output_mats = torch.split(
......
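
Both backward passes perform the mirror-image restore: the saved weights come back as plain tensors, so they are re-wrapped as `torch.nn.Parameter` with `main_grad` reattached, and the `grad_added_to_main_grad` flags stashed in forward are put back. A hedged restatement (the function wrapper is hypothetical; attribute names match the diff):

import torch

def _restore_weight_state(ctx, weights, main_grads):
    if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
        for i in range(ctx.num_gemms):
            # Re-wrap with the original requires_grad so wgrad is still produced.
            w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
            w.main_grad = main_grads[i]
            weights[i] = w
            if ctx.offload_activation and weights[i].requires_grad:
                # Undo the temporary grad_added_to_main_grad = True from forward.
                weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
    return weights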
@@ -211,6 +211,28 @@ class _GroupedLinear(torch.autograd.Function):
                 if isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(columnwise_usage=True)
+        offload_activation = False
+        if hasattr(inp, "offloading_activation"):
+            offload_activation = True
+            for i in range(num_gemms):
+                inputmats[i].offloading_activation = inp.offloading_activation
+        ctx.offload_activation = offload_activation
+        if offload_activation and cpu_offloading:
+            raise ValueError(
+                "Do not use offload_activation and cpu_offloading at the same time."
+            )
+        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+            grad_added_to_main_grad_list = []
+            for weight in weights:
+                if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
+                    grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                    weight.grad_added_to_main_grad = True
+                else:
+                    grad_added_to_main_grad_list.append(None)
+            ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
         tensors_to_save, tensor_objects = prepare_for_saving(
             *inputmats,
             *weights_fp8,
@@ -273,11 +295,13 @@ class _GroupedLinear(torch.autograd.Function):
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
+                if ctx.offload_activation and weights[i].requires_grad:
+                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
         # Preprocess grad output
         grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
......
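
Taken together, a hypothetical call site might look like the following; `te.GroupedLinear`, its constructor arguments, and the training-loop lines are assumptions for illustration, not part of this MR:

import torch
# import transformer_engine.pytorch as te   # assumed import path

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
x.offloading_activation = True   # opt this input's saved activations into offloading

# layer = te.GroupedLinear(4, 1024, 1024)   # hypothetical construction
# out = layer(x, m_splits=[4, 4, 4, 4])     # forward tags saved inputmats, stashes flags
# out.sum().backward()                      # backward restores main_grad bookkeeping

# Enabling CPU offloading at the same time raises:
#   ValueError: Do not use offload_activation and cpu_offloading at the same time.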