"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "088aa19a89dbc7512299b2221923af1195d0950c"
Unverified Commit 677de00d authored by Kirthi Shankar Sivamani, committed by GitHub

Numerics fix from #40 (#59)

Stats and dgrad bug fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f985db38
@@ -417,11 +417,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_init(num_gemms=num_gemms)
         self.set_fp8_weights()
 
-        # Either we're in FP8 training or calibration for FP8 inference
-        needs_stats = (self.training if self.fp8 else self.fp8_calibration)
-        if needs_stats:
         update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
 
         # Previous iteration was grad_enabled
         if self.fp8_meta.get("update_amax_and_scale_fwd", False):
             if self.fp8_meta["recipe"].reduce_amax:
@@ -434,6 +431,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 amax_and_scale_update(
                     self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                 )
+
+        if self.fp8 and self.training:
             # Setup for amax reduction
             if self.fp8_meta["recipe"].reduce_amax:
                 self.fp8_meta["first_module"] = is_first_fp8_module()
@@ -1395,7 +1394,6 @@ class _Linear(torch.autograd.Function):
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
-        ctx.requires_dgrad = inputmat.requires_grad
         ctx.requires_wgrad = weight.requires_grad
 
         # Row Parallel Linear
@@ -1461,9 +1459,7 @@ class _Linear(torch.autograd.Function):
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
 
-        if ctx.requires_dgrad:
             # DGRAD
-            if ctx.fp8:
             dgrad = fp8_gemm(
                 weight_t_fp8,
                 fwd_scale_inverses,
@@ -1555,7 +1551,7 @@ class _Linear(torch.autograd.Function):
             wgrad if ctx.requires_wgrad else None,
             None,
             None,
-            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
+            dgrad.view(ctx.inp_shape),
             grad_bias,
             None,
             None,
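The `_Linear` hunks are the "dgrad" half. `inputmat` is a processed copy of the input created inside `forward`, and a custom `torch.autograd.Function.forward` runs with autograd disabled, so `inputmat.requires_grad` is always False; gating the backward pass on it meant the input gradient was silently never computed. A minimal standalone reproduction of the pattern (the `MyLinear` function here is hypothetical, not the real `_Linear`):

```python
import torch

class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # forward() runs in no-grad mode, so this processed copy reports
        # requires_grad=False even though the original input requires grad.
        inputmat = inp.reshape(-1, inp.shape[-1])
        assert inp.requires_grad and not inputmat.requires_grad
        ctx.save_for_backward(inputmat, weight)
        ctx.inp_shape = inp.shape
        return inputmat @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inputmat, weight = ctx.saved_tensors
        # Gating dgrad on inputmat.requires_grad (always False here) would
        # silently drop the input gradient; compute it unconditionally.
        dgrad = grad_output @ weight
        wgrad = grad_output.t() @ inputmat
        return dgrad.view(ctx.inp_shape), wgrad

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(3, 8, requires_grad=True)
MyLinear.apply(x, w).sum().backward()
assert x.grad is not None  # dgrad now flows back to the real input
```

With the gate removed, backward always returns `dgrad.view(ctx.inp_shape)`, matching the last hunk; an alternative would be gating on `ctx.needs_input_grad`, which autograd fills in against the original inputs.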