Unverified commit a2caec1e, authored by Kirthi Shankar Sivamani and committed by GitHub

fp8_autocast bug fix when switching from non-fp8 execution (#2)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1531dc78
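
What goes wrong without this fix: set_fp8_weights() previously stored empty torch.Tensor() placeholders in the weight{i}_fp8 / weight{i}_t_fp8 attributes whenever a module ran with FP8 disabled. On a later run inside fp8_autocast, the old `if not hasattr(...)` guard saw those placeholders, skipped allocating the real int8 buffers, and FP8 execution proceeded with zero-element tensors. A minimal sketch of that failure mode, assuming the public te.Linear module and fp8_autocast entry point (this reproducer is illustrative, not part of the commit):

import torch
import transformer_engine.pytorch as te

linear = te.Linear(768, 768)
inp = torch.randn(32, 768, device="cuda")

# 1. Forward pass with FP8 disabled: the old set_fp8_weights() assigned
#    empty torch.Tensor() placeholders to weight1_fp8 / weight1_t_fp8.
out = linear(inp)

# 2. Forward pass with FP8 enabled: the old `if not hasattr(...)` guard
#    found the placeholder attributes already present and never allocated
#    correctly shaped int8 buffers, breaking FP8 execution.
with te.fp8_autocast(enabled=True):
    out = linear(inp)
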
@@ -226,11 +226,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         weights. This function will iterate over those shapes and initialize
         respective attributed named `weight1_fp8`, `weight2_fp8`, ...
         """
+        if not self.fp8:
+            return
+
         for i, shape in enumerate(self.fp8_weight_shapes, start=1):
             weight_cast_attr = f"weight{i}_fp8"
             weight_transpose_attr = f"weight{i}_t_fp8"
-            if self.fp8:
-                if not hasattr(self, weight_cast_attr):
-                    setattr(
-                        self,
-                        weight_cast_attr,
+
+            if (
+                hasattr(self, weight_cast_attr)
+                and getattr(self, weight_cast_attr).shape == shape
+            ):
+                return
+
+            setattr(
+                self,
+                weight_cast_attr,
@@ -240,7 +248,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
-                            dtype=torch.int8,
-                        ),
-                    )
-                if not hasattr(self, weight_transpose_attr):
-                    setattr(
-                        self,
-                        weight_transpose_attr,
+                    dtype=torch.int8,
+                ),
+            )
+            setattr(
+                self,
+                weight_transpose_attr,
@@ -251,9 +258,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
-                            dtype=torch.int8,
-                        ),
-                    )
-            else:
-                setattr(self, weight_cast_attr, torch.Tensor())
-                setattr(self, weight_transpose_attr, torch.Tensor())
+                    dtype=torch.int8,
+                ),
+            )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """Set TP group."""
@@ -483,8 +487,8 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_weight: torch.Tensor,
         ln_bias: torch.Tensor,
         weight: torch.Tensor,
-        weight_fp8: torch.Tensor,
-        weight_t_fp8: torch.Tensor,
+        weight_fp8: Union[torch.Tensor, None],
+        weight_t_fp8: Union[torch.Tensor, None],
         bias: torch.Tensor,
         use_bias: bool,
         eps: float,
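
The remaining hunks apply one uniform pattern: each autograd function widens its FP8 weight arguments to Union[torch.Tensor, None], and each module's forward() guards the attribute access with `if self.fp8 else None`, since the weight{i}_fp8 attributes may no longer exist when the module has only ever run with FP8 disabled.
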
@@ -1030,8 +1034,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.layer_norm_weight,
             self.layer_norm_bias,
             weight if weight is not None else self.weight,
-            self.weight1_fp8,
-            self.weight1_t_fp8,
+            self.weight1_fp8 if self.fp8 else None,
+            self.weight1_t_fp8 if self.fp8 else None,
             bias_tensor,
             self.use_bias,
             self.eps,
@@ -1072,8 +1076,8 @@ class _Linear(torch.autograd.Function):
     def forward(
         ctx,
         weight: torch.Tensor,
-        weight_fp8: torch.Tensor,
-        weight_t_fp8: torch.Tensor,
+        weight_fp8: Union[torch.Tensor, None],
+        weight_t_fp8: Union[torch.Tensor, None],
         inp: torch.Tensor,
         bias: torch.Tensor,
         use_bias: bool,
@@ -1548,8 +1552,8 @@ class Linear(TransformerEngineBaseModule):
 
         out = _Linear.apply(
             weight if weight is not None else self.weight,
-            self.weight1_fp8,
-            self.weight1_t_fp8,
+            self.weight1_fp8 if self.fp8 else None,
+            self.weight1_t_fp8 if self.fp8 else None,
             inp,
             bias_tensor,
             self.use_bias,
@@ -1585,12 +1589,12 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_weight: torch.Tensor,
         ln_bias: torch.Tensor,
         fc1_weight: torch.Tensor,
-        fc1_weight_fp8: torch.Tensor,
-        fc1_weight_t_fp8: torch.Tensor,
+        fc1_weight_fp8: Union[torch.Tensor, None],
+        fc1_weight_t_fp8: Union[torch.Tensor, None],
         fc1_bias: torch.Tensor,
         fc2_weight: torch.Tensor,
-        fc2_weight_fp8: torch.Tensor,
-        fc2_weight_t_fp8: torch.Tensor,
+        fc2_weight_fp8: Union[torch.Tensor, None],
+        fc2_weight_t_fp8: Union[torch.Tensor, None],
         fc2_bias: torch.Tensor,
         use_bias: bool,
         eps: float,
@@ -2336,12 +2340,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_weight,
             self.layer_norm_bias,
             self.fc1_weight,
-            self.weight1_fp8,
-            self.weight1_t_fp8,
+            self.weight1_fp8 if self.fp8 else None,
+            self.weight1_t_fp8 if self.fp8 else None,
             self.fc1_bias,
             self.fc2_weight,
-            self.weight2_fp8,
-            self.weight2_t_fp8,
+            self.weight2_fp8 if self.fp8 else None,
+            self.weight2_t_fp8 if self.fp8 else None,
             self.fc2_bias,
             False, # use_bias set to False for RPL
             self.eps,
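
With these guards in place, alternating the same module between non-FP8 and FP8 execution is expected to work in either order. A short usage sketch, assuming LayerNormMLP takes hidden and FFN sizes as its first two constructor arguments:

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(1024, 4096)
x = torch.randn(16, 1024, device="cuda")

y = mlp(x)                     # FP8 off: forward passes None for the FP8 weights
with te.fp8_autocast(enabled=True):
    y = mlp(x)                 # FP8 on: int8 buffers allocated with matching shapes
y = mlp(x)                     # switching back again is safe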