Unverified Commit a2caec1e authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

fp8_autocast bug fix when switching from non-fp8 execution (#2)


Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1531dc78
...@@ -226,11 +226,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -226,11 +226,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
weights. This function will iterate over those shapes and initialize weights. This function will iterate over those shapes and initialize
respective attributed named `weight1_fp8`, `weight2_fp8`, ... respective attributed named `weight1_fp8`, `weight2_fp8`, ...
""" """
if not self.fp8:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1): for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_attr = f"weight{i}_fp8" weight_cast_attr = f"weight{i}_fp8"
weight_transpose_attr = f"weight{i}_t_fp8" weight_transpose_attr = f"weight{i}_t_fp8"
if self.fp8:
if not hasattr(self, weight_cast_attr): if (
hasattr(self, weight_cast_attr)
and getattr(self, weight_cast_attr).shape == shape
):
return
setattr( setattr(
self, self,
weight_cast_attr, weight_cast_attr,
...@@ -240,7 +248,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -240,7 +248,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
dtype=torch.int8, dtype=torch.int8,
), ),
) )
if not hasattr(self, weight_transpose_attr):
setattr( setattr(
self, self,
weight_transpose_attr, weight_transpose_attr,
...@@ -251,9 +258,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -251,9 +258,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
dtype=torch.int8, dtype=torch.int8,
), ),
) )
else:
setattr(self, weight_cast_attr, torch.Tensor())
setattr(self, weight_transpose_attr, torch.Tensor())
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group.""" """Set TP group."""
...@@ -483,8 +487,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -483,8 +487,8 @@ class _LayerNormLinear(torch.autograd.Function):
ln_weight: torch.Tensor, ln_weight: torch.Tensor,
ln_bias: torch.Tensor, ln_bias: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
weight_fp8: torch.Tensor, weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: torch.Tensor, weight_t_fp8: Union[torch.Tensor, None],
bias: torch.Tensor, bias: torch.Tensor,
use_bias: bool, use_bias: bool,
eps: float, eps: float,
...@@ -1030,8 +1034,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1030,8 +1034,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
weight if weight is not None else self.weight, weight if weight is not None else self.weight,
self.weight1_fp8, self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8, self.weight1_t_fp8 if self.fp8 else None,
bias_tensor, bias_tensor,
self.use_bias, self.use_bias,
self.eps, self.eps,
...@@ -1072,8 +1076,8 @@ class _Linear(torch.autograd.Function): ...@@ -1072,8 +1076,8 @@ class _Linear(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
weight: torch.Tensor, weight: torch.Tensor,
weight_fp8: torch.Tensor, weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: torch.Tensor, weight_t_fp8: Union[torch.Tensor, None],
inp: torch.Tensor, inp: torch.Tensor,
bias: torch.Tensor, bias: torch.Tensor,
use_bias: bool, use_bias: bool,
...@@ -1548,8 +1552,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -1548,8 +1552,8 @@ class Linear(TransformerEngineBaseModule):
out = _Linear.apply( out = _Linear.apply(
weight if weight is not None else self.weight, weight if weight is not None else self.weight,
self.weight1_fp8, self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8, self.weight1_t_fp8 if self.fp8 else None,
inp, inp,
bias_tensor, bias_tensor,
self.use_bias, self.use_bias,
...@@ -1585,12 +1589,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1585,12 +1589,12 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight: torch.Tensor, ln_weight: torch.Tensor,
ln_bias: torch.Tensor, ln_bias: torch.Tensor,
fc1_weight: torch.Tensor, fc1_weight: torch.Tensor,
fc1_weight_fp8: torch.Tensor, fc1_weight_fp8: Union[torch.Tensor, None],
fc1_weight_t_fp8: torch.Tensor, fc1_weight_t_fp8: Union[torch.Tensor, None],
fc1_bias: torch.Tensor, fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor, fc2_weight: torch.Tensor,
fc2_weight_fp8: torch.Tensor, fc2_weight_fp8: Union[torch.Tensor, None],
fc2_weight_t_fp8: torch.Tensor, fc2_weight_t_fp8: Union[torch.Tensor, None],
fc2_bias: torch.Tensor, fc2_bias: torch.Tensor,
use_bias: bool, use_bias: bool,
eps: float, eps: float,
...@@ -2336,12 +2340,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2336,12 +2340,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
self.fc1_weight, self.fc1_weight,
self.weight1_fp8, self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8, self.weight1_t_fp8 if self.fp8 else None,
self.fc1_bias, self.fc1_bias,
self.fc2_weight, self.fc2_weight,
self.weight2_fp8, self.weight2_fp8 if self.fp8 else None,
self.weight2_t_fp8, self.weight2_t_fp8 if self.fp8 else None,
self.fc2_bias, self.fc2_bias,
False, # use_bias set to False for RPL False, # use_bias set to False for RPL
self.eps, self.eps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment