Unverified commit 11e9d669, authored by Hongbin Liu, committed by GitHub

Fix bug when enabling --overlap-grad-reduce in mcore (#2142)



* fix bugs when enabling --overlap-grad-reduce in mcore
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CI
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e9a5fa4e
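
The hunks below remove "if ... .grad is None" guards around the delayed-wgrad assignments. As a minimal, self-contained sketch of the before/after behavior (not the actual Transformer Engine code), and assuming, as the commit title suggests, that Megatron-core's --overlap-grad-reduce path can leave param.grad pre-populated between iterations:

import torch

weight = torch.nn.Parameter(torch.zeros(4))
# Assumption: with --overlap-grad-reduce, .grad may already point at a
# (stale or zeroed) buffer rather than being None.
weight.grad = torch.zeros_like(weight)

wgrad = torch.ones(4)  # the delayed wgrad popped from the wgrad store

# Old pattern (removed in this commit): skips the write when .grad exists.
if weight.grad is None:
    weight.grad = wgrad.to(weight.dtype)
assert torch.all(weight.grad == 0)  # the computed wgrad was silently dropped

# New pattern: assign unconditionally when wgrad accumulation is not fused.
weight.grad = wgrad.to(weight.dtype)
assert torch.all(weight.grad == 1)
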
@@ -1482,7 +1482,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         (wgrad, bgrad), _ = self.wgrad_store.pop()
         if not self.fuse_wgrad_accumulation:
             weight_tensor = noop_cat(self._get_weight_tensors())
-            if weight_tensor.grad is None:
-                weight_tensor.grad = wgrad.to(weight_tensor.dtype)
+            weight_tensor.grad = wgrad.to(weight_tensor.dtype)
         if self.use_bias:
             bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
@@ -452,9 +452,6 @@ class _GroupedLinear(torch.autograd.Function):
         else:
             wgrad_list = [None] * ctx.num_gemms

-        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-            wgrad_list = [None] * ctx.num_gemms
-
         if not ctx.use_bias or (
             ctx.wgrad_store is not None
             and ctx.wgrad_store.delay_wgrad_compute()
@@ -829,7 +826,6 @@ class GroupedLinear(TransformerEngineBaseModule):
         bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
         if not self.fuse_wgrad_accumulation:
             for i in range(self.num_gemms):
-                if weight_params[i].grad is None:
-                    weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
+                weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
         if self.use_bias:
             for i in range(self.num_gemms):
@@ -1197,7 +1197,6 @@ class _LayerNormMLP(torch.autograd.Function):
                     "with Userbuffers (tensor-parallel communication overlapping)"
                 )
             ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
-            fc1_wgrad = None
             if fuse_gemm_and_bias_fc1_wgrad:
                 fc1_bias_grad = None
             else:
@@ -2168,9 +2167,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 if self.fc1_bias.grad is None:
                     self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
             if not self.fuse_wgrad_accumulation:
-                if self.fc2_weight.grad is None:
-                    self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
-                if self.fc1_weight.grad is None:
-                    self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
+                self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
+                self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
             del fc2_bias_grad_
             del fc2_wgrad
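
For context, the hunks above touch the delayed-wgrad machinery: the backward pass stashes the wgrad GEMM via wgrad_store.put(...), and the module later retrieves the result via "(wgrad, bgrad), _ = self.wgrad_store.pop()". The stand-in below is a hedged sketch of that queue contract inferred only from the call sites visible in this diff; WgradStore's internals here are assumptions, not the actual TE implementation.

from collections import deque

import torch

class WgradStore:
    """Stand-in for the delayed-wgrad queue used above (internals assumed)."""

    def __init__(self, delay_wgrad: bool = True):
        self._delay_wgrad = delay_wgrad
        self._queue = deque()

    def delay_wgrad_compute(self) -> bool:
        # Whether wgrad GEMMs are deferred out of the backward pass.
        return self._delay_wgrad

    def put(self, tensors, wgrad_gemm_fn):
        # Stash the GEMM inputs and the function that will compute the wgrad.
        self._queue.append((tensors, wgrad_gemm_fn))

    def pop(self):
        # Run the oldest deferred GEMM; mirrors the
        # "(wgrad, bgrad), _ = self.wgrad_store.pop()" call site above.
        tensors, wgrad_gemm_fn = self._queue.popleft()
        return wgrad_gemm_fn(*tensors), None

# Usage mirroring the diff: defer during backward, apply afterwards.
store = WgradStore()
x, dy = torch.randn(8, 4), torch.randn(8, 2)
store.put([x, dy], lambda inp, grad: (inp.t() @ grad, grad.sum(0)))
(wgrad, bgrad), _ = store.pop()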