Unverified Commit 931b44fe authored by Selvaraj Anandaraj, committed by GitHub

Fixed convergence issues with CPU offloading (#1026)



* Fixed convergence issues
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 33a3d02f
transformer_engine/pytorch/module/layernorm_linear.py
@@ -289,8 +289,6 @@ class _LayerNormLinear(torch.autograd.Function):
         if is_grad_enabled:
             if cpu_offloading:
-                if fuse_wgrad_accumulation:
-                    weight.main_grad.weight_offloading = True
                 if fp8 and weight_fp8 is not None:
                     weight_fp8.weight_offloading = True
                 ln_weight.weight_offloading = True
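The removed lines tagged the fused weight-gradient accumulation buffer (`weight.main_grad`) for CPU offloading alongside the weight itself; after this change only the weight tensors carry the `weight_offloading` attribute, so the fp32 accumulation buffer stays on the device. Below is a minimal sketch of attribute-tag selection under that assumption; `collect_offloadable` and the tensor shapes are hypothetical stand-ins, not Transformer Engine's actual offload machinery.

```python
import torch

def collect_offloadable(tensors):
    """Hypothetical helper: keep only tensors tagged as offloadable weights."""
    return [t for t in tensors if getattr(t, "weight_offloading", False)]

weight = torch.randn(16, 16)
weight.weight_offloading = True              # the weight itself may be moved to host memory
weight.main_grad = torch.zeros_like(weight)  # fp32 wgrad accumulation buffer

# With the tag removed from main_grad, only the weight is ever selected for
# offloading; the accumulation buffer remains resident across micro-batches.
selected = collect_offloadable([weight, weight.main_grad])
assert len(selected) == 1 and selected[0] is weight
```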
@@ -411,7 +409,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-                weight = torch.nn.Parameter(weight, False)
+                weight = torch.nn.Parameter(weight, weight.requires_grad)
                 weight.main_grad = main_grad
             if ctx.ub_overlap_rs_dgrad:
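In the backward pass the saved weight is re-wrapped as a Parameter so that `main_grad` can be re-attached to it. The removed line always wrapped it with `requires_grad=False`, which drops the gradient flag; the replacement carries the original flag through the re-wrap. The standalone sketch below shows only that flag behavior, using a hypothetical 4x4 weight; the same pattern repeats in the _LayerNormMLP and _Linear hunks that follow.

```python
import torch

weight = torch.nn.Parameter(torch.randn(4, 4))  # requires_grad=True by default
main_grad = torch.zeros_like(weight)            # fused wgrad accumulation buffer

saved = weight.detach()                         # stand-in for the tensor seen in backward

# Removed line's behavior: the re-wrapped parameter always reports requires_grad=False.
old = torch.nn.Parameter(saved, False)
assert old.requires_grad is False

# Replacement's behavior: the original flag survives the re-wrap.
new = torch.nn.Parameter(saved, weight.requires_grad)
assert new.requires_grad is True

# main_grad is re-attached after re-wrapping, as in the diff above.
new.main_grad = main_grad
```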
transformer_engine/pytorch/module/layernorm_mlp.py
@@ -425,9 +425,6 @@ class _LayerNormMLP(torch.autograd.Function):
         if is_grad_enabled:
             if cpu_offloading:
-                if fuse_wgrad_accumulation:
-                    fc1_weight.main_grad.weight_offloading = True
-                    fc2_weight.main_grad.weight_offloading = True
                 if fp8 and fc1_weight_fp8 is not None:
                     fc1_weight_fp8.weight_offloading = True
                 if fp8 and fc2_weight_fp8 is not None:
@@ -570,8 +567,8 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-                fc1_weight = Parameter(fc1_weight, False)
-                fc2_weight = Parameter(fc2_weight, False)
+                fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad)
+                fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad)
                 fc1_weight.main_grad = fc1_weight_main_grad
                 fc2_weight.main_grad = fc2_weight_main_grad
transformer_engine/pytorch/module/linear.py
@@ -310,8 +310,6 @@ class _Linear(torch.autograd.Function):
                 saved_inputmat = inputmat_no_fp8
             if cpu_offloading:
-                if fuse_wgrad_accumulation:
-                    weight.main_grad.weight_offloading = True
                 if fp8 and weight_fp8 is not None:
                     weight_fp8.weight_offloading = True
                 weight.weight_offloading = True
@@ -403,7 +401,7 @@ class _Linear(torch.autograd.Function):
             )
             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-                weight = torch.nn.Parameter(weight, False)
+                weight = torch.nn.Parameter(weight, weight.requires_grad)
                 weight.main_grad = main_grad
             tp_world_size = get_distributed_world_size(ctx.tp_group)