Commit 94c92919 authored by Youngeun Kwon, committed by GitHub

Adding remove_caches API to Float8Tensor class (#1425)



* add remove_caches api
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* Update transformer_engine/pytorch/tensor/float8_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* explicit delete
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 9351a179
......
@@ -334,6 +334,14 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
        """
        self._transpose_invalid = True

    def remove_caches(self) -> None:
        """
        Remove transpose cache and mark it as invalid.
        """
        self._transpose_invalid = True
        del self._transpose  # explicitly deletes the data for safety
        self._transpose = None

    def clear(self):
        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
        self._data = torch.Tensor() if self._data is not None else None
......
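
The effect of `remove_caches` can be illustrated with a small stand-in class. This is only a sketch under stated assumptions: `CachedTransposeSketch` and its `transpose` method are hypothetical and use plain Python lists rather than the real `Float8Tensor` or any Transformer Engine API. Only the body of `remove_caches` mirrors the hunk above.

```python
class CachedTransposeSketch:
    """Hypothetical stand-in for a tensor that lazily caches its transpose."""

    def __init__(self, rows):
        self._data = [list(row) for row in rows]
        self._transpose = None          # cached transpose, built on demand
        self._transpose_invalid = True  # True means the cache must be rebuilt

    def transpose(self):
        """Return the transpose, rebuilding the cache if it is missing or invalid."""
        if self._transpose_invalid or self._transpose is None:
            self._transpose = [list(col) for col in zip(*self._data)]
            self._transpose_invalid = False
        return self._transpose

    def remove_caches(self):
        """Same three steps as the method added in the diff."""
        self._transpose_invalid = True
        del self._transpose   # drop this object's reference to the cached data
        self._transpose = None


t = CachedTransposeSketch([[1, 2, 3], [4, 5, 6]])
first = t.transpose()                # builds and caches the transpose
t.remove_caches()                    # cache dropped and flagged invalid
cache_after_remove = t._transpose    # None until the next transpose() call
second = t.transpose()               # rebuilt from self._data
```

In this sketch the only difference from merely setting `_transpose_invalid = True` is that the cached object is released immediately rather than kept until the next rebuild, which is presumably the point of the new method when the cache holds a large buffer.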