Unverified commit 6a855962 authored by Selvaraj Anandaraj, committed by GitHub

Distopt with offload (#1573)



* DistOpt support with offloading
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Added distopt support for TE2.0
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Restricted this to MCore DistOpt only
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Added guards
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7ddc5932
transformer_engine/pytorch/module/layernorm_linear.py
@@ -383,6 +383,17 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
+
+            if cpu_offloading:
+                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+                if ctx.grad_added_to_main_grad:
+                    # Passing a torch.nn.Parameter through the Torch saved-tensor hooks
+                    # returns a plain torch.Tensor; the Parameter wrapper is stripped off.
+                    # The weight object is preserved here so that all attributes the user
+                    # set on the weight survive. Because of this, offloading weights is not
+                    # recommended if they are touched externally outside this module.
+                    ctx.weight_object = weight
+
             tensors_to_save, tensor_objects = prepare_for_saving(
                 inputmat,
                 weightmat,
@@ -526,8 +537,11 @@
         # For CPU offloading, weight and weight.main_grad were offloaded to different
         # tensors, so we need to reconnect them here.
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                origin_weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                origin_weight.main_grad = main_grad

         ctx.ub_obj_gradout = None
         ub_obj_dgrad = None
......
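A minimal, self-contained sketch of the failure mode the code comment above describes (illustrative only, not TE code; the _ToyLinear name is made up). Custom attributes such as main_grad live on the torch.nn.Parameter wrapper, but tensors that come back through saved-tensor hooks, which is how CPU offloading repacks saved tensors, are plain torch.Tensor objects, so keeping a reference to the original Parameter on ctx is the way to reach those attributes in backward:

import torch

class _ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        if hasattr(weight, "grad_added_to_main_grad"):
            # Keep the Parameter object itself so its user-set attributes survive.
            ctx.weight_object = weight
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors  # plain tensors once the unpack hook has run
        origin_weight = getattr(ctx, "weight_object", weight)
        if hasattr(origin_weight, "main_grad"):
            # MCore-style fused accumulation: write the wgrad into main_grad directly.
            origin_weight.main_grad += grad_out.t() @ inp
        return grad_out @ weight, None

w = torch.nn.Parameter(torch.randn(4, 8))
w.main_grad = torch.zeros_like(w)        # attribute set by the distributed optimizer
w.grad_added_to_main_grad = False
x = torch.randn(2, 8, requires_grad=True)

# The pack/unpack hooks stand in for offloading: the weight seen in backward is a
# repacked plain tensor, without the Parameter wrapper or its attributes.
with torch.autograd.graph.saved_tensors_hooks(lambda t: t.detach().clone(), lambda t: t):
    out = _ToyLinear.apply(x, w)
out.sum().backward()
print(bool(w.main_grad.abs().sum() > 0))  # True: the gradient landed in main_grad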
transformer_engine/pytorch/module/linear.py
@@ -291,6 +291,17 @@ class _Linear(torch.autograd.Function):
                 )
                 nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
+
+            if cpu_offloading:
+                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+                if ctx.grad_added_to_main_grad:
+                    # Passing a torch.nn.Parameter through the Torch saved-tensor hooks
+                    # returns a plain torch.Tensor; the Parameter wrapper is stripped off.
+                    # The weight object is preserved here so that all attributes the user
+                    # set on the weight survive. Because of this, offloading weights is not
+                    # recommended if they are touched externally outside this module.
+                    ctx.weight_object = weight
+
             # TODO(ksivamani): Check memory usage
             tensors_to_save, tensor_objects = prepare_for_saving(
                 saved_inputmat,
@@ -392,8 +403,10 @@
             else None
         )
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight = torch.nn.Parameter(weight, weight.requires_grad)
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                weight.main_grad = main_grad

         # Gather intermediate/activation tensors if needed
......
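The backward-side change, identical in both modules, restores the link between the reloaded weight and main_grad before the wgrad accumulation runs, and only on the MCore distributed-optimizer path (signalled, per the commit message, by the grad_added_to_main_grad attribute). A standalone sketch of that gating with a hypothetical reconnect_main_grad helper and a stand-in ctx; this mirrors the pattern, it is not TE's API:

import torch
from types import SimpleNamespace

def reconnect_main_grad(ctx, weight, main_grad):
    # Return the weight object whose .main_grad the wgrad step should accumulate into.
    if ctx.cpu_offloading:
        if ctx.grad_added_to_main_grad:
            # The Parameter preserved in forward still carries the user-set attributes.
            weight = ctx.weight_object
        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
            # weight and weight.main_grad were offloaded as separate tensors; relink them.
            weight.main_grad = main_grad
    return weight

param = torch.nn.Parameter(torch.randn(4, 8))
param.grad_added_to_main_grad = False
ctx = SimpleNamespace(               # stand-in for the autograd ctx populated in forward
    cpu_offloading=True,
    grad_added_to_main_grad=True,
    requires_wgrad=True,
    fuse_wgrad_accumulation=True,
    weight_object=param,
)
reloaded_weight = param.detach().clone()   # what backward sees after the reload
main_grad = torch.zeros_like(param)        # reloaded separately from the weight
w = reconnect_main_grad(ctx, reloaded_weight, main_grad)
print(w is param, w.main_grad is main_grad)  # True True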