Unverified Commit 2f643ada authored by Kirthi Shankar Sivamani, committed by GitHub

Gradient enablement bug fix (#72)



Fix use of the PyTorch training flag: gate gradient-related work on the autograd state (torch.is_grad_enabled()) rather than on module.training.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent eda8f461
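
The diff below swaps the module's `training` flag for PyTorch's autograd state when deciding whether the forward pass must stash tensors for backward. The two flags are independent: a module can be in eval mode while gradients are still required, and it can be in train mode under `torch.no_grad()`. A minimal, standalone sketch of that distinction (the `ToyLinear` module and its names are illustrative only, not the TransformerEngine code):

```python
import torch

class ToyLinear(torch.nn.Module):
    """Tiny stand-in module used only to show the training-vs-grad distinction."""
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # Gate backward bookkeeping on the autograd state, not self.training:
        # self.training controls train/eval behaviour (dropout, BN statistics),
        # torch.is_grad_enabled() says whether autograd will record this op.
        needs_backward = torch.is_grad_enabled()
        out = x @ self.weight.t()
        if needs_backward:
            pass  # real code would save tensors needed by the backward pass here
        return out

model = ToyLinear()
x = torch.randn(2, 4)

# Case 1: train mode under no_grad -> no graph is built, nothing to save.
model.train()
with torch.no_grad():
    print(model.training, torch.is_grad_enabled())  # True False

# Case 2: eval mode with grads enabled -> backward is still possible,
# so tensors must be saved even though self.training is False.
model.eval()
y = model(x)
print(model.training, torch.is_grad_enabled())      # False True
y.sum().backward()                                   # works; model.weight.grad is populated
```
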
@@ -665,7 +665,7 @@ class _LayerNormLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         return_layernorm_output: bool,
-        is_training: bool,
+        is_grad_enabled: bool,
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -688,7 +688,7 @@ class _LayerNormLinear(torch.autograd.Function):
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

             if not return_layernorm_output:
-                if is_training:
+                if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
                         inputmat,
                         ln_weight,
@@ -711,7 +711,7 @@ class _LayerNormLinear(torch.autograd.Function):
                         fp8_dtype_forward,
                     )
             else:
-                if is_training:
+                if is_grad_enabled:
                     ln_out_return, mu, rsigma = tex.layernorm_fwd(
                         inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                     )
@@ -727,7 +727,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     fp8_dtype_forward,
                 )
         else:
-            if is_training:
+            if is_grad_enabled:
                 ln_out, mu, rsigma = tex.layernorm_fwd(
                     inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
@@ -752,7 +752,7 @@ class _LayerNormLinear(torch.autograd.Function):
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

             if update_fp8_weights:
-                if is_training:
+                if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
@@ -806,7 +806,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 use_bias=use_bias,
             )

-        if is_training:
+        if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
                 ln_weight,
@@ -1299,17 +1299,17 @@ class LayerNormLinear(TransformerEngineBaseModule):
             bias_tensor = (
                 bias if bias is not None
                 else self.bias if self.parameters_split is None
-                else self.bias_tensor if not self.training
+                else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
                 weight if weight is not None
                 else self.weight if self.parameters_split is None
-                else self.weight_tensor if not self.training
+                else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
             )

-            if self.training:
+            if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
                 args = []
             else:
@@ -1336,7 +1336,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.activation_dtype,
                 self.parallel_mode,
                 self.return_layernorm_output,
-                self.training,
+                torch.is_grad_enabled(),
                 self.fwd_ln_sm_margin,
                 self.bwd_ln_sm_margin,
             )
@@ -1380,7 +1380,7 @@ class _Linear(torch.autograd.Function):
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
-        is_training: bool,
+        is_grad_enabled: bool,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -1397,7 +1397,7 @@ class _Linear(torch.autograd.Function):
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

             if not fp8_meta["recipe"].override_linear_precision.wgrad:
-                if is_training:
+                if is_grad_enabled:
                     inputmat, inputmat_t = fp8_cast_transpose_fused(
                         inputmat,
                         fp8_meta["scaling_fwd"],
@@ -1434,7 +1434,7 @@ class _Linear(torch.autograd.Function):
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

             if update_fp8_weights:
-                if is_training:
+                if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
@@ -1489,7 +1489,7 @@ class _Linear(torch.autograd.Function):
                 use_bias=use_bias,
             )

-        if is_training:
+        if is_grad_enabled:
             fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
             ctx.save_for_backward(
                 inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
@@ -1916,17 +1916,17 @@ class Linear(TransformerEngineBaseModule):
             bias_tensor = (
                 bias if bias is not None
                 else self.bias if self.parameters_split is None
-                else self.bias_tensor if not self.training
+                else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
                 weight if weight is not None
                 else self.weight if self.parameters_split is None
-                else self.weight_tensor if not self.training
+                else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
             )

-            if self.training:
+            if torch.is_grad_enabled():
                 linear_fn = _Linear.apply
                 args = []
             else:
@@ -1949,7 +1949,7 @@ class Linear(TransformerEngineBaseModule):
                 self.tp_size > 1,
                 self.activation_dtype,
                 self.parallel_mode,
-                self.training,
+                torch.is_grad_enabled(),
             )
             out = linear_fn(*args)
@@ -1994,7 +1994,7 @@ class _LayerNormMLP(torch.autograd.Function):
         return_layernorm_output: bool,
         bias_gelu_nvfusion: bool,
         set_parallel_mode: bool,
-        is_training: bool,
+        is_grad_enabled: bool,
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -2016,7 +2016,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if not return_layernorm_output:
-                if is_training:
+                if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
                         inputmat,
                         ln_weight,
@@ -2048,7 +2048,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     fp8_dtype_forward,
                 )
         else:
-            if is_training:
+            if is_grad_enabled:
                 ln_out, mu, rsigma = tex.layernorm_fwd(
                     inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
@@ -2074,7 +2074,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias

             if update_fp8_weights:
-                if is_training:
+                if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         fc1_weight,
                         fp8_meta["scaling_fwd"],
@@ -2173,7 +2173,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 gelu=not bias_gelu_nvfusion,
             )

-            if bias_gelu_nvfusion and is_training:
+            if bias_gelu_nvfusion and is_grad_enabled:
                 fc1_out, _, _ = fc1_outputs
                 gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
             else:
@@ -2195,7 +2195,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc2_bias,
                 use_bias=use_bias,
             )
-        if is_training:
+        if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
                 ln_weight,
@@ -2805,7 +2805,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         """

         with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
-            if self.training:
+            if torch.is_grad_enabled():
                 fwd_fn = _LayerNormMLP.apply
                 args = []
             else:
@@ -2837,7 +2837,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.return_layernorm_output,
                 self.bias_gelu_nvfusion,
                 self.set_parallel_mode,
-                self.training,
+                torch.is_grad_enabled(),
                 self.fwd_ln_sm_margin,
                 self.bwd_ln_sm_margin,
             )
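
The same gating pattern recurs in every hunk above: the `torch.autograd.Function.forward` only calls `ctx.save_for_backward` (and only does gradient-only work such as the fused cast-transpose) when `is_grad_enabled` is true, and the module-level `forward` dispatches through `Function.apply` only when autograd is recording. A condensed, hypothetical sketch of that pattern (names like `_ToyFn` and `toy_module_forward` are illustrative, not the TransformerEngine API):

```python
import torch

class _ToyFn(torch.autograd.Function):
    """Skeleton mirroring the grad-enabled gating used in the diff above."""

    @staticmethod
    def forward(ctx, inp, weight, is_grad_enabled):
        out = inp @ weight.t()
        if is_grad_enabled:
            # Only pay the memory cost of saved activations when a backward
            # pass can actually happen.
            ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        return grad_out @ weight, grad_out.t() @ inp, None


def toy_module_forward(inp, weight):
    # Dispatch the way the modules in the diff do: go through autograd only
    # when gradients are being recorded; otherwise call forward() directly
    # with a dummy ctx so no graph is constructed at all.
    if torch.is_grad_enabled():
        return _ToyFn.apply(inp, weight, True)
    return _ToyFn.forward(None, inp, weight, False)


x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)

y = toy_module_forward(x, w)           # grads enabled: tensors saved, backward works
y.sum().backward()

with torch.no_grad():
    y_eval = toy_module_forward(x, w)  # inference: nothing saved, no graph built
```
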