Unverified Commit 677de00d authored by Kirthi Shankar Sivamani, committed by GitHub

Numerics fix from #40 (#59)

Stats and dgrad bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f985db38
@@ -417,23 +417,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_init(num_gemms=num_gemms)
         self.set_fp8_weights()

-        # Either we're in FP8 training or calibration for FP8 inference
-        needs_stats = (self.training if self.fp8 else self.fp8_calibration)
-
-        if needs_stats:
-            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
-            # Previous iteration was grad_enabled
-            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-                if self.fp8_meta["recipe"].reduce_amax:
-                    copy_amax_from_global_buffer(self.fp8_meta, forward=True)
-                    amax_and_scale_update(
-                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
-                    )
-                    set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
-                else:
-                    amax_and_scale_update(
-                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
-                    )
+        update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
+
+        # Previous iteration was grad_enabled
+        if self.fp8_meta.get("update_amax_and_scale_fwd", False):
+            if self.fp8_meta["recipe"].reduce_amax:
+                copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+                amax_and_scale_update(
+                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                )
+                set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+            else:
+                amax_and_scale_update(
+                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                )
+
+        if self.fp8 and self.training:
             # Setup for amax reduction
             if self.fp8_meta["recipe"].reduce_amax:
                 self.fp8_meta["first_module"] = is_first_fp8_module()
@@ -446,17 +445,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     self.fp8_meta["autocast_id_fwd"]
                 )
                 add_amax_to_global_buffer(self.fp8_meta, forward=True)
             self.fp8_meta["update_amax_and_scale_fwd"] = True
         else:
             self.fp8_meta["update_amax_and_scale_fwd"] = False

         # Activation recomputation is used and this is the first forward phase.
         if (
             self.fp8
             and is_fp8_activation_recompute_enabled()
             and not in_fp8_activation_recompute_phase()
         ):
             copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

         with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
             yield inp.contiguous()
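Note on the first two hunks: the fix separates applying a pending amax/scale update from scheduling the next one. The update deferred from the previous grad-enabled iteration is now flushed unconditionally, and only the scheduling of a new update is gated on self.fp8 and self.training; under the old needs_stats gate, an FP8 inference step could skip the flush and keep quantizing with stale scales, which appears to be the numerics issue being fixed. A minimal, self-contained sketch of the fixed ordering (Meta and step are hypothetical stand-ins, not TransformerEngine APIs):

class Meta:
    """Stand-in for fp8_meta: one scale and a pending-update flag."""
    def __init__(self):
        self.pending = False   # analogue of "update_amax_and_scale_fwd"
        self.amax = 1.0
        self.scale = 1.0

def step(meta, fp8, training):
    # 1) Always flush the update deferred from the last grad-enabled step,
    #    even under eval/no-grad, so the scale never goes stale.
    if meta.pending:
        meta.scale = 448.0 / meta.amax   # rescale from amax (448 = E4M3 max)
        meta.pending = False
    # 2) Only schedule the next update when training in FP8.
    meta.pending = fp8 and training

meta = Meta()
meta.amax = 4.0
step(meta, fp8=True, training=True)    # grad-enabled step: schedules an update
step(meta, fp8=True, training=False)   # eval step: the update is still applied
print(meta.scale)                      # 112.0 rather than the stale 1.0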
@@ -1395,7 +1394,6 @@ class _Linear(torch.autograd.Function):
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
-        ctx.requires_dgrad = inputmat.requires_grad
         ctx.requires_wgrad = weight.requires_grad

         # Row Parallel Linear
@@ -1461,32 +1459,30 @@ class _Linear(torch.autograd.Function):
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )

-            if ctx.requires_dgrad:
-                # DGRAD
-                if ctx.fp8:
-                    dgrad = fp8_gemm(
-                        weight_t_fp8,
-                        fwd_scale_inverses,
-                        tex.FP8FwdTensors.GEMM1_WEIGHT,
-                        fp8_dtype_forward,
-                        grad_output_c,
-                        ctx.fp8_meta["scaling_bwd"].scale_inv,
-                        tex.FP8BwdTensors.GRAD_OUTPUT1,
-                        fp8_dtype_backward,
-                        ctx.activation_dtype,
-                        get_workspace(),
-                        use_split_accumulator=_2X_ACC_DGRAD,
-                    )
-                else:
-                    # DGRAD
-                    dgrad, _, _ = gemm(
-                        weight,
-                        grad_output,
-                        ctx.activation_dtype,
-                        get_workspace(),
-                        layout="NN",
-                        grad=True,
-                    )
+            # DGRAD
+            dgrad = fp8_gemm(
+                weight_t_fp8,
+                fwd_scale_inverses,
+                tex.FP8FwdTensors.GEMM1_WEIGHT,
+                fp8_dtype_forward,
+                grad_output_c,
+                ctx.fp8_meta["scaling_bwd"].scale_inv,
+                tex.FP8BwdTensors.GRAD_OUTPUT1,
+                fp8_dtype_backward,
+                ctx.activation_dtype,
+                get_workspace(),
+                use_split_accumulator=_2X_ACC_DGRAD,
+            )
+        else:
+            # DGRAD
+            dgrad, _, _ = gemm(
+                weight,
+                grad_output,
+                ctx.activation_dtype,
+                get_workspace(),
+                layout="NN",
+                grad=True,
+            )

         # Overlap dgrad-RS/AR with wgrad
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
@@ -1555,7 +1551,7 @@ class _Linear(torch.autograd.Function):
             wgrad if ctx.requires_wgrad else None,
             None,
             None,
-            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
+            dgrad.view(ctx.inp_shape),
             grad_bias,
             None,
             None,
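Note on the remaining hunks: ctx.requires_dgrad was keyed off inputmat.requires_grad, and with it gone, dgrad is computed and returned unconditionally (dgrad.view(ctx.inp_shape) in the last hunk). One plausible failure mode of the old gate, inferred from the removed line: inputmat is a staging copy of the input, and a copy made outside autograd does not inherit requires_grad from the original. A standalone illustration (not TransformerEngine code):

import torch

# The caller's input participates in autograd...
inp = torch.randn(4, 8, requires_grad=True)

# ...but a staging copy made under no_grad (e.g. a cast prepared for the
# GEMM) reports requires_grad == False.
with torch.no_grad():
    inputmat = inp.to(torch.bfloat16)

print(inp.requires_grad)       # True  -> a dgrad is still expected upstream
print(inputmat.requires_grad)  # False -> the old gate would have skipped it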