Unverified commit d3efaebb authored by Sudhakar Singh, committed by GitHub
Browse files

Delete extra tensor objects after restoring float8 tensors (#1500)



* delete extra tensor objects after restoring float8 tensors
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* nit fix
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix the leak in float8tensor and mxfloat8tensor classes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* uncomment the fix
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 303c6d16
......@@ -448,6 +448,9 @@ class _LayerNormLinear(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
......
......@@ -567,6 +567,10 @@ class _LayerNormMLP(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
fc1_weight_main_grad = (
ctx.fc1_main_grad
......
......@@ -354,6 +354,9 @@ class _Linear(torch.autograd.Function):
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
......
......@@ -105,8 +105,8 @@ class Float8TensorBase:
"""
tensors = [self._data, self._transpose]
# self._data = None
# self._transpose = None
self._data = None
self._transpose = None
return tensors, self
def restore_from_saved(
......
......@@ -100,8 +100,8 @@ class MXFP8TensorBase:
"""
tensors = [self._rowwise_data, self._columnwise_data]
# self._rowwise_data = None
# self._columnwise_data = None
self._rowwise_data = None
self._columnwise_data = None
return tensors, self
def restore_from_saved(
......
......@@ -348,6 +348,15 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
    """Prepare this tensor for being saved for backward.

    Hands the tensor object itself over as the single entry of the
    saved-tensors list, with ``None`` as the retained tensor-base
    placeholder, so no separate base object keeps references alive.
    """
    saved: list[Optional[torch.Tensor]] = [self]
    return saved, None
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
......@@ -285,6 +285,15 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
    """Prepare this tensor for being saved for backward.

    Hands the tensor object itself over as the single entry of the
    saved-tensors list, with ``None`` as the retained tensor-base
    placeholder, so no separate base object keeps references alive.
    """
    saved: list[Optional[torch.Tensor]] = [self]
    return saved, None
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment