Unverified Commit f985db38 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Bug fixes from #40 (#58)



* Fix conflict from #40
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7fc079a4
......@@ -967,6 +967,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1377,14 +1378,10 @@ class _Linear(torch.autograd.Function):
if is_training:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad
else None,
inputmat_t if weight.requires_grad and fp8_wgrad
else None,
weight if inputmat.requires_grad and not fp8
else None,
weight_t_fp8 if inputmat.requires_grad and fp8
else None,
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
weight,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
......@@ -2389,6 +2386,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment