Unverified Commit 82555b3f authored by Alp Dener, committed by GitHub

[PyTorch] Linear and LayerNormLinear bug fix for excess weight and bias buffers (#570)



* Clean up Linear and LayerNormLinear weight and bias buffers at the end of init when there is no parameter split
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in tensor name
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in tensor name
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent e7261e11
@@ -880,6 +880,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
 
+        # Clean up weight and bias buffers
+        if self.parameters_split is None:
+            del self.weight_tensor
+            if self.use_bias:
+                del self.bias_tensor
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         if not self.zero_centered_gamma:
@@ -767,6 +767,12 @@ class Linear(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False
 
+        # Clean up weight and bias buffers
+        if self.parameters_split is None:
+            del self.weight_tensor
+            if self.use_bias:
+                del self.bias_tensor
+
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
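For context, here is a minimal, hypothetical sketch of the buffer lifetime this change fixes (the ToyLinear class and its signature are invented for illustration and are not Transformer Engine code). When parameters_split is None, the full weight_tensor and bias_tensor buffers are registered as the module's parameters directly; since nn.Parameter shares storage with its source tensor, keeping the plain attributes leaves a redundant second reference on the module, which the del statements remove without affecting the registered parameters.

import torch
import torch.nn as nn

# Hypothetical stand-in for the patched modules; not Transformer Engine code.
class ToyLinear(nn.Module):
    def __init__(self, in_features, out_features, parameters_split=None):
        super().__init__()
        self.parameters_split = parameters_split
        self.use_bias = True
        # Full-size buffers allocated up front, mirroring __init__ in the diff.
        self.weight_tensor = torch.empty(out_features, in_features)
        self.bias_tensor = torch.zeros(out_features)
        if self.parameters_split is None:
            # No split: each buffer becomes the single registered parameter.
            # nn.Parameter shares storage with the tensor it wraps.
            self.weight = nn.Parameter(self.weight_tensor)
            self.bias = nn.Parameter(self.bias_tensor)
            # The fix from the diff above: drop the leftover plain attributes.
            del self.weight_tensor
            if self.use_bias:
                del self.bias_tensor

layer = ToyLinear(4, 8)
assert not hasattr(layer, "weight_tensor") and not hasattr(layer, "bias_tensor")
assert isinstance(layer.weight, nn.Parameter)  # parameters remain intact

The guard on parameters_split matches the diff: when a split is requested, the individual parameters are presumably carved out of the shared buffers, so the backing tensors have to stay alive and are left untouched.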