Unverified Commit 931b44fe authored by Selvaraj Anandaraj, committed by GitHub
Browse files

Fixed convergence issues with CPU offloading (#1026)



* Fixed convergence issues
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 33a3d02f
...@@ -289,8 +289,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -289,8 +289,6 @@ class _LayerNormLinear(torch.autograd.Function):
if is_grad_enabled: if is_grad_enabled:
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8 and weight_fp8 is not None: if fp8 and weight_fp8 is not None:
weight_fp8.weight_offloading = True weight_fp8.weight_offloading = True
ln_weight.weight_offloading = True ln_weight.weight_offloading = True
...@@ -411,7 +409,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -411,7 +409,7 @@ class _LayerNormLinear(torch.autograd.Function):
) )
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False) weight = torch.nn.Parameter(weight.requires_grad)
weight.main_grad = main_grad weight.main_grad = main_grad
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
......
...@@ -425,9 +425,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -425,9 +425,6 @@ class _LayerNormMLP(torch.autograd.Function):
if is_grad_enabled: if is_grad_enabled:
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation:
fc1_weight.main_grad.weight_offloading = True
fc2_weight.main_grad.weight_offloading = True
if fp8 and fc1_weight_fp8 is not None: if fp8 and fc1_weight_fp8 is not None:
fc1_weight_fp8.weight_offloading = True fc1_weight_fp8.weight_offloading = True
if fp8 and fc2_weight_fp8 is not None: if fp8 and fc2_weight_fp8 is not None:
...@@ -570,8 +567,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -570,8 +567,8 @@ class _LayerNormMLP(torch.autograd.Function):
) )
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
fc1_weight = Parameter(fc1_weight, False) fc1_weight = Parameter(fc1_weight.requires_grad)
fc2_weight = Parameter(fc2_weight, False) fc2_weight = Parameter(fc2_weight.requires_grad)
fc1_weight.main_grad = fc1_weight_main_grad fc1_weight.main_grad = fc1_weight_main_grad
fc2_weight.main_grad = fc2_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad
......
...@@ -310,8 +310,6 @@ class _Linear(torch.autograd.Function): ...@@ -310,8 +310,6 @@ class _Linear(torch.autograd.Function):
saved_inputmat = inputmat_no_fp8 saved_inputmat = inputmat_no_fp8
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8 and weight_fp8 is not None: if fp8 and weight_fp8 is not None:
weight_fp8.weight_offloading = True weight_fp8.weight_offloading = True
weight.weight_offloading = True weight.weight_offloading = True
...@@ -403,7 +401,7 @@ class _Linear(torch.autograd.Function): ...@@ -403,7 +401,7 @@ class _Linear(torch.autograd.Function):
) )
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False) weight = torch.nn.Parameter(weight.requires_grad)
weight.main_grad = main_grad weight.main_grad = main_grad
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment