Unverified Commit 1bafd1a6 authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #681 from pengcheng888/issue/679_1

issue/679 - 减少创建和使用infini.tensor对象过程中的属性调用次数,以降低耗时
parents 1887c3f1 0c0489c5
...@@ -2,16 +2,20 @@ from infinicore.lib import _infinicore ...@@ -2,16 +2,20 @@ from infinicore.lib import _infinicore
class device: class device:
def __init__(self, type=None, index=None): # Public attributes describing the device
if type is None: type: str
type = "cpu" index: int
_underlying: _infinicore.Device
def __init__(self, type=None, index=None):
if isinstance(type, device): if isinstance(type, device):
self.type = type.type self.type = type.type
self.index = type.index self.index = type.index
return return
if type is None:
type = "cpu"
if ":" in type: if ":" in type:
if index is not None: if index is not None:
raise ValueError( raise ValueError(
...@@ -22,12 +26,14 @@ class device: ...@@ -22,12 +26,14 @@ class device:
index = int(index) index = int(index)
self.type = type self.type = type
self.index = index if index else 0
self.index = index
def __getattr__(self, name):
_type, _index = device._to_infinicore_device(type, index if index else 0) # Lazily construct and cache an attribute.
# such as, self._underlying .
self._underlying = _infinicore.Device(_type, _index) _type, _index = device._to_infinicore_device(self.type, self.index)
setattr(self, name, _infinicore.Device(_type, _index))
return getattr(self, name)
def __repr__(self): def __repr__(self):
return f"device(type='{self.type}'{f', index={self.index}' if self.index is not None else ''})" return f"device(type='{self.type}'{f', index={self.index}' if self.index is not None else ''})"
......
...@@ -14,30 +14,35 @@ from .utils import ( ...@@ -14,30 +14,35 @@ from .utils import (
class Tensor: class Tensor:
# Public attributes describing the device
_underlying: _infinicore.Tensor
_torch_ref: "torch.Tensor" # noqa: F821
shape: list[int]
dtype: infinicore.dtype
device: infinicore.device
def __init__(self, underlying, *, _torch_ref=None): def __init__(self, underlying, *, _torch_ref=None):
"""An internal method. Please do not use this directly.""" """An internal method. Please do not use this directly."""
self._underlying = underlying self._underlying = underlying
self._dtype = infinicore.dtype(self._underlying.dtype)
self._device = infinicore.device._from_infinicore_device(
self._underlying.device
)
self._torch_ref = _torch_ref self._torch_ref = _torch_ref
@property def __getattr__(self, name):
def shape(self): # Lazily construct and cache an attribute.
return self._underlying.shape # such as, self.shape, self.dtype, self.device .
if name == "shape":
@property setattr(self, name, getattr(self._underlying, name))
def dtype(self): elif name == "dtype":
return self._dtype setattr(self, name, infinicore.dtype(getattr(self._underlying, name)))
elif name == "device":
@property setattr(
def device(self): self,
return self._device name,
infinicore.device._from_infinicore_device(
getattr(self._underlying, name)
),
)
return getattr(self, name)
@property @property
def ndim(self): def ndim(self):
...@@ -101,6 +106,10 @@ class Tensor: ...@@ -101,6 +106,10 @@ class Tensor:
def __add__(self, other): def __add__(self, other):
return infinicore.add(self, other) return infinicore.add(self, other)
def __iadd__(self, other):
infinicore.add(self, other, out=self)
return self
def __matmul__(self, other): def __matmul__(self, other):
return infinicore.matmul(self, other) return infinicore.matmul(self, other)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment