Unverified commit 2f643ada, authored by Kirthi Shankar Sivamani, committed by GitHub

Gradient enablement bug fix (#72)

Fix use of the PyTorch training flag: backward-pass bookkeeping was gated on Module.training, but the correct condition is torch.is_grad_enabled(), since the two flags are independent.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent eda8f461
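For context: PyTorch's Module.training flag (toggled by model.train() / model.eval()) is independent of gradient enablement (toggled by torch.no_grad() / torch.enable_grad()). A module left in training mode but run under torch.no_grad() still reports training == True even though no backward pass can occur, which is the mismatch this commit fixes. A minimal sketch of the two flags, using a hypothetical model that is not part of this repository:

    import torch

    model = torch.nn.Linear(4, 4)
    model.train()  # sets only the Module.training flag

    with torch.no_grad():
        # Still "training" as far as the module flag is concerned,
        # but autograd is off, so no backward state should be saved.
        print(model.training)           # True
        print(torch.is_grad_enabled())  # False

    model.eval()
    # The converse also holds: eval mode with gradients enabled,
    # e.g. when computing input gradients against a frozen model.
    print(model.training)           # False
    print(torch.is_grad_enabled())  # True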
@@ -665,7 +665,7 @@ class _LayerNormLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         return_layernorm_output: bool,
-        is_training: bool,
+        is_grad_enabled: bool,
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -688,7 +688,7 @@ class _LayerNormLinear(torch.autograd.Function):
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
             if not return_layernorm_output:
-                if is_training:
+                if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
                         inputmat,
                         ln_weight,
@@ -711,7 +711,7 @@ class _LayerNormLinear(torch.autograd.Function):
                         fp8_dtype_forward,
                     )
             else:
-                if is_training:
+                if is_grad_enabled:
                     ln_out_return, mu, rsigma = tex.layernorm_fwd(
                         inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                     )
@@ -727,7 +727,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     fp8_dtype_forward,
                 )
         else:
-            if is_training:
+            if is_grad_enabled:
                 ln_out, mu, rsigma = tex.layernorm_fwd(
                     inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
@@ -752,7 +752,7 @@ class _LayerNormLinear(torch.autograd.Function):
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
 
             if update_fp8_weights:
-                if is_training:
+                if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
@@ -806,7 +806,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 use_bias=use_bias,
             )
 
-        if is_training:
+        if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
                 ln_weight,
@@ -1299,17 +1299,17 @@ class LayerNormLinear(TransformerEngineBaseModule):
             bias_tensor = (
                 bias if bias is not None
                 else self.bias if self.parameters_split is None
-                else self.bias_tensor if not self.training
+                else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
                 weight if weight is not None
                 else self.weight if self.parameters_split is None
-                else self.weight_tensor if not self.training
+                else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
             )
 
-            if self.training:
+            if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
                 args = []
             else:
@@ -1336,7 +1336,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.activation_dtype,
                 self.parallel_mode,
                 self.return_layernorm_output,
-                self.training,
+                torch.is_grad_enabled(),
                 self.fwd_ln_sm_margin,
                 self.bwd_ln_sm_margin,
             )
@@ -1380,7 +1380,7 @@ class _Linear(torch.autograd.Function):
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
-        is_training: bool,
+        is_grad_enabled: bool,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -1397,7 +1397,7 @@ class _Linear(torch.autograd.Function):
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
             if not fp8_meta["recipe"].override_linear_precision.wgrad:
-                if is_training:
+                if is_grad_enabled:
                     inputmat, inputmat_t = fp8_cast_transpose_fused(
                         inputmat,
                         fp8_meta["scaling_fwd"],
@@ -1434,7 +1434,7 @@ class _Linear(torch.autograd.Function):
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
 
             if update_fp8_weights:
-                if is_training:
+                if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
@@ -1489,7 +1489,7 @@ class _Linear(torch.autograd.Function):
                 use_bias=use_bias,
             )
 
-        if is_training:
+        if is_grad_enabled:
             fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
             ctx.save_for_backward(
                 inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
@@ -1916,17 +1916,17 @@ class Linear(TransformerEngineBaseModule):
             bias_tensor = (
                 bias if bias is not None
                 else self.bias if self.parameters_split is None
-                else self.bias_tensor if not self.training
+                else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
                 weight if weight is not None
                 else self.weight if self.parameters_split is None
-                else self.weight_tensor if not self.training
+                else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
             )
 
-            if self.training:
+            if torch.is_grad_enabled():
                 linear_fn = _Linear.apply
                 args = []
             else:
@@ -1949,7 +1949,7 @@ class Linear(TransformerEngineBaseModule):
                 self.tp_size > 1,
                 self.activation_dtype,
                 self.parallel_mode,
-                self.training,
+                torch.is_grad_enabled(),
             )
 
             out = linear_fn(*args)
@@ -1994,7 +1994,7 @@ class _LayerNormMLP(torch.autograd.Function):
         return_layernorm_output: bool,
         bias_gelu_nvfusion: bool,
         set_parallel_mode: bool,
-        is_training: bool,
+        is_grad_enabled: bool,
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -2016,7 +2016,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if not return_layernorm_output:
-                if is_training:
+                if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
                         inputmat,
                         ln_weight,
@@ -2048,7 +2048,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     fp8_dtype_forward,
                 )
         else:
-            if is_training:
+            if is_grad_enabled:
                 ln_out, mu, rsigma = tex.layernorm_fwd(
                     inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
@@ -2074,7 +2074,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias
 
             if update_fp8_weights:
-                if is_training:
+                if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         fc1_weight,
                         fp8_meta["scaling_fwd"],
@@ -2173,7 +2173,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 gelu=not bias_gelu_nvfusion,
             )
 
-            if bias_gelu_nvfusion and is_training:
+            if bias_gelu_nvfusion and is_grad_enabled:
                 fc1_out, _, _ = fc1_outputs
                 gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
             else:
@@ -2195,7 +2195,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc2_bias,
                 use_bias=use_bias,
             )
-        if is_training:
+        if is_grad_enabled:
            ctx.save_for_backward(
                 inputmat,
                 ln_weight,
@@ -2805,7 +2805,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         """
 
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
-            if self.training:
+            if torch.is_grad_enabled():
                 fwd_fn = _LayerNormMLP.apply
                 args = []
             else:
@@ -2837,7 +2837,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.return_layernorm_output,
                 self.bias_gelu_nvfusion,
                 self.set_parallel_mode,
-                self.training,
+                torch.is_grad_enabled(),
                 self.fwd_ln_sm_margin,
                 self.bwd_ln_sm_margin,
             )
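The practical effect of the change is on the ctx.save_for_backward calls: tensors saved for backward are only needed while autograd is recording, so gating on grad enablement avoids holding activations in memory during no-grad inference even when the module is still in training mode. A minimal sketch of the same pattern with a hypothetical autograd function (_Square is illustrative only, not this repository's _Linear):

    import torch

    class _Square(torch.autograd.Function):
        # Illustrative example of the pattern this commit adopts:
        # backward bookkeeping is gated on grad enablement.
        @staticmethod
        def forward(ctx, inp: torch.Tensor, is_grad_enabled: bool) -> torch.Tensor:
            if is_grad_enabled:
                # Save the input only when a backward pass can actually run.
                ctx.save_for_backward(inp)
            return inp * inp

        @staticmethod
        def backward(ctx, grad_out: torch.Tensor):
            (inp,) = ctx.saved_tensors
            # d(x^2)/dx = 2x; the bool flag gets no gradient.
            return grad_out * 2.0 * inp, None

    x = torch.randn(8, requires_grad=True)
    y = _Square.apply(x, torch.is_grad_enabled())
    y.sum().backward()  # works: the input was saved

    with torch.no_grad():
        _Square.apply(x, torch.is_grad_enabled())  # nothing saved, no graph built

Had the gate been x.training-style module state instead, a model left in train mode would keep saving activations inside torch.no_grad() blocks for a backward pass that never comes.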