Unverified commit b9e7b0b8 authored by Li Tao, committed by GitHub

Cache torch.Tensor() to reduce CPU overhead (#1759)

* use functools.lru_cache to cache torch.Tensor()

Signed-off-by: lit <lit@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove duplicated definition

Signed-off-by: lit <lit@nvidia.com>

* Update transformer_engine/pytorch/tensor/utils.py

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: lit <lit@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 421084cf
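For context, the change replaces per-call construction of an empty CUDA tensor with a single cached instance. A minimal sketch of the pattern, mirroring the helper this commit adds to transformer_engine/pytorch/utils.py (the final assert is illustrative, not part of the commit):

import functools

import torch


@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
    """Get tensor with no entries and no data."""
    # Built once per process; torch.Tensor() allocates no storage, and
    # .cuda() merely produces a zero-element tensor on the CUDA device.
    return torch.Tensor().cuda()


# Subsequent calls are cache hits returning the identical object,
# skipping both tensor construction and the CUDA runtime call.
assert _empty_tensor() is _empty_tensor()

Because the tensor has no elements, handing the same instance to every caller is safe, and lru_cache(maxsize=None) here amounts to a one-entry cache that lives for the lifetime of the process.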
@@ -3,13 +3,13 @@
 # See LICENSE for license information.
 
 """Python interface for GEMM extensions"""
-import functools
+
 from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
 
 from ..constants import TE_DType
-from ..utils import get_sm_count
+from ..utils import get_sm_count, _empty_tensor
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
@@ -21,12 +21,6 @@ __all__ = [
 ]
 
-
-@functools.lru_cache(maxsize=None)
-def _empty_tensor() -> torch.Tensor:
-    """Get tensor with no entries and no data"""
-    return torch.Tensor().cuda()
-
 
 def general_gemm(
     A: torch.Tensor,
     B: torch.Tensor,
...
@@ -18,6 +18,8 @@ from ...constants import TE_DType_To_Torch
 from ..quantized_tensor import Quantizer
+from ...utils import _empty_tensor
+
 
 class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of Float8BlockwiseQTensor.
@@ -68,7 +70,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
             self._columnwise_scale_inv,
         ):
             if t is not None:
-                t.data = torch.Tensor()
+                t.data = _empty_tensor()
 
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
...
@@ -18,7 +18,7 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ..quantized_tensor import Quantizer
-from ...utils import is_non_tn_fp8_gemm_supported
+from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor
 
 
 class _FromFloat8Func(torch.autograd.Function):
@@ -98,7 +98,7 @@ class Float8TensorBase(QuantizedTensorBase):
         """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
         for t in (self._data, self._transpose, self._scale_inv):
             if t is not None:
-                t.data = torch.Tensor()
+                t.data = _empty_tensor()
         self._transpose_invalid = True
 
     def get_metadata(self) -> Dict[str, Any]:
...
@@ -17,6 +17,8 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ..quantized_tensor import Quantizer
+from ...utils import _empty_tensor
+
 
 class _FromMXFP8Func(torch.autograd.Function):
     """Cast from MXFP8 to other dtype"""
@@ -92,7 +94,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
             self._columnwise_scale_inv,
         ):
             if t is not None:
-                t.data = torch.Tensor()
+                t.data = _empty_tensor()
 
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
...
@@ -5,7 +5,6 @@
 """Helper functions for using fp8 tensors as weights"""
 
 import torch
 import transformer_engine_torch as tex
-from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
...
@@ -24,6 +24,12 @@ def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     return False
 
 
+@functools.lru_cache(maxsize=None)
+def _empty_tensor() -> torch.Tensor:
+    """Get tensor with no entries and no data"""
+    return torch.Tensor().cuda()
+
+
 def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """
     Trick to deallocate tensor memory when delete operation does not
@@ -36,7 +42,7 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
         if hasattr(t, "clear"):
             t.clear()
         else:
-            t.data = torch.Tensor()
+            t.data = _empty_tensor()
         del t
...
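A note on why the t.data = _empty_tensor() lines above deallocate memory: rebinding a tensor's .data drops its reference to the old storage, so the CUDA caching allocator can reclaim it while the Python object (and any outstanding references to it) remains valid. A hypothetical standalone demonstration, not part of the commit, assuming a CUDA device is available:

import torch

t = torch.ones(1024, 1024, device="cuda")  # ~4 MiB of CUDA memory
before = torch.cuda.memory_allocated()

# Rebinding .data releases the only reference to the 4 MiB storage,
# returning it to PyTorch's caching allocator.
t.data = torch.Tensor().cuda()  # or the cached _empty_tensor()

assert torch.cuda.memory_allocated() < before
assert t.numel() == 0  # the tensor object itself stays alive

Before this change, each such release constructed a fresh torch.Tensor().cuda(); in loops over many weight and activation buffers that per-call Python and CUDA-runtime overhead adds up on the CPU, which is the cost the cached helper removes.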