Unverified commit 6a855962, authored by Selvaraj Anandaraj, committed by GitHub
Browse files

Distopt with offload (#1573)



* DistOpt support with offloading
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Added distopt support for TE2.0
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Restricted this to MCore DistOpt only
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Added guards
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7ddc5932
@@ -383,6 +383,17 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             weightmat,
@@ -526,8 +537,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                origin_weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                origin_weight.main_grad = main_grad

         ctx.ub_obj_gradout = None
         ub_obj_dgrad = None
...
@@ -291,6 +291,17 @@ class _Linear(torch.autograd.Function):
                 )
                 nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
@@ -392,9 +403,11 @@ class _Linear(torch.autograd.Function):
             else None
         )
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight = torch.nn.Parameter(weight, weight.requires_grad)
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                weight.main_grad = main_grad

         # Gather intermediate/activation tensors if needed
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment