Unverified Commit 452c7374 authored by Selvaraj Anandaraj, committed by GitHub

Added support for DistOpt with offloading with MoE's (#2264)


Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 88564d59
@@ -209,6 +209,19 @@ class _GroupedLinear(torch.autograd.Function):
                 if isinstance(weight, QuantizedTensorStorage):
                     weight.update_usage(columnwise_usage=True)
 
+            if cpu_offloading:
+                ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
+                if ctx.grad_added_to_main_grad:
+                    # If a torch.nn.Parameter is passed through the Torch hooks, a plain
+                    # torch.Tensor comes back: Torch strips off the Parameter wrapper.
+                    # Preserve the weight objects here so that all attributes the user
+                    # sets on the weights stay available. Because of this, offloading
+                    # weights is not recommended if the weights are touched externally
+                    # outside this module.
+                    ctx.weight_objects = []
+                    for weight in weights:
+                        ctx.weight_objects.append(weight)
+
             tensors_to_save, tensor_objects = prepare_for_saving(
                 *inputmats,
                 *weights_fp8,
@@ -271,11 +284,15 @@ class _GroupedLinear(torch.autograd.Function):
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
 
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
-                w.main_grad = main_grads[i]
-                weights[i] = w
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                for i, weight in enumerate(ctx.weight_objects):
+                    origin_weights[i] = ctx.weight_objects[i]
+                    ctx.weight_objects[i] = None
+            if ctx.fuse_wgrad_accumulation:
+                for i in range(N):
+                    origin_weights[i].main_grad = main_grads[i]
 
         # Preprocess grad output
         grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
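
For context on the comment in the first hunk, the sketch below illustrates why the Parameter objects are stashed on ctx. It is not part of this commit; the _Square function and the pack_to_cpu / unpack_from_cpu hook names are hypothetical stand-ins for what CPU offloading does (move saved tensors to host memory in forward and bring them back for backward). When a weight is round-tripped through PyTorch's saved-tensor hooks, the value handed back in backward is a plain torch.Tensor, so user-defined attributes such as main_grad set on the nn.Parameter are no longer reachable from the saved tensor.

import torch

class _Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        ctx.save_for_backward(w)
        return w * w

    @staticmethod
    def backward(ctx, grad_out):
        (w,) = ctx.saved_tensors
        # After the offload round trip, w is a plain torch.Tensor and no longer
        # carries the user-defined main_grad attribute set on the Parameter.
        print(type(w), hasattr(w, "main_grad"))
        return 2 * w * grad_out

weight = torch.nn.Parameter(torch.randn(4))
weight.main_grad = torch.zeros_like(weight)  # user-defined attribute, as in TE

# Hypothetical stand-ins for the offload hooks: pack saved tensors to host
# memory during forward, restore them onto the original device for backward.
def pack_to_cpu(t):
    return t.detach().to("cpu")

def unpack_from_cpu(t):
    return t.to(weight.device)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = _Square.apply(weight).sum()
loss.backward()  # prints something like: <class 'torch.Tensor'> False

This is why the backward path in the second hunk reads main_grad from the preserved origin_weights objects rather than from the tensors restored by the offload hooks.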