Unverified Commit b9e7b0b8 authored by Li Tao's avatar Li Tao Committed by GitHub
Browse files

Cache torch.Tensor() to reduce CPU overhead (#1759)



* use lru to cache torch.Tensor()
Signed-off-by: lit <lit@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove duplicated definition
Signed-off-by: default avatarlit <lit@nvidia.com>

* Update transformer_engine/pytorch/tensor/utils.py
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarlit <lit@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 421084cf
......@@ -3,13 +3,13 @@
# See LICENSE for license information.
"""Python interface for GEMM extensions"""
import functools
from typing import Iterable, Optional, Tuple, Union, List
import os
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count
from ..utils import get_sm_count, _empty_tensor
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
......@@ -21,12 +21,6 @@ __all__ = [
]
@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
    """Return a cached zero-element CUDA tensor.

    The tensor holds no data, so a single shared instance is safe to reuse;
    lru_cache ensures the CPU-side allocation happens only once.
    """
    placeholder = torch.Tensor()
    return placeholder.cuda()
def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
......
......@@ -18,6 +18,8 @@ from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
......@@ -68,7 +70,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
self._columnwise_scale_inv,
):
if t is not None:
t.data = torch.Tensor()
t.data = _empty_tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
......
......@@ -18,7 +18,7 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import is_non_tn_fp8_gemm_supported
from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor
class _FromFloat8Func(torch.autograd.Function):
......@@ -98,7 +98,7 @@ class Float8TensorBase(QuantizedTensorBase):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (self._data, self._transpose, self._scale_inv):
if t is not None:
t.data = torch.Tensor()
t.data = _empty_tensor()
self._transpose_invalid = True
def get_metadata(self) -> Dict[str, Any]:
......
......@@ -17,6 +17,8 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
class _FromMXFP8Func(torch.autograd.Function):
"""Cast from MXFP8 to other dtype"""
......@@ -92,7 +94,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
self._columnwise_scale_inv,
):
if t is not None:
t.data = torch.Tensor()
t.data = _empty_tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
......
......@@ -5,7 +5,6 @@
"""Helper functions for using fp8 tensors as weights"""
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
......
......@@ -24,6 +24,12 @@ def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
return False
@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
    """Return a cached zero-element CUDA tensor.

    Used as a cheap sentinel when deallocating tensor storage; since the
    tensor has no entries, one memoized instance can be shared everywhere,
    avoiding a fresh allocation (and host-device round trip) per call.
    """
    blank = torch.Tensor()
    return blank.cuda()
def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""
Trick to deallocate tensor memory when delete operation does not
......@@ -36,7 +42,7 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
if hasattr(t, "clear"):
t.clear()
else:
t.data = torch.Tensor()
t.data = _empty_tensor()
del t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment