Unverified Commit 677de00d authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Numerics fix from #40 (#59)



Stats and dgrad bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f985db38
......@@ -417,11 +417,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
# Either we're in FP8 training or calibration for FP8 inference
needs_stats = (self.training if self.fp8 else self.fp8_calibration)
if needs_stats:
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
......@@ -434,6 +431,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
......@@ -1395,7 +1394,6 @@ class _Linear(torch.autograd.Function):
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.requires_dgrad = inputmat.requires_grad
ctx.requires_wgrad = weight.requires_grad
# Row Parallel Linear
......@@ -1461,9 +1459,7 @@ class _Linear(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.requires_dgrad:
# DGRAD
if ctx.fp8:
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
......@@ -1555,7 +1551,7 @@ class _Linear(torch.autograd.Function):
wgrad if ctx.requires_wgrad else None,
None,
None,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgrad.view(ctx.inp_shape),
grad_bias,
None,
None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment