Unverified Commit 4575a329 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[hotfix] ColoTensor pin_memory (#840)

parent 9f6f6569
......@@ -20,21 +20,27 @@ class ColoTensor(object):
dtype=None,
requires_grad=False,
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
):
self._size = size
self._dtype = dtype
self._requires_grad = requires_grad
self._pin_memory = pin_memory
self._device = device
self._torch_tensor = torch_tensor
def numel(self):
return sum(self._size)
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor):
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.pin_memory,
torch_tensor=tensor)
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0))
return colo_t
def del_torch_tensor(self) -> None:
......@@ -42,12 +48,12 @@ class ColoTensor(object):
self._torch_tensor = torch.empty(self._size)
def torch_tensor(self) -> torch.Tensor:
if self._torch_tensor == None or self._torch_tensor.numel() == 0:
print(self._size, type(self._size))
if self._torch_tensor.numel() == 0:
self._torch_tensor = torch.empty(*self._size,
dtype=self._dtype,
pin_memory=self._pin_memory,
requires_grad=self._requires_grad,
pin_memory=self._pin_memory)
device=self._device)
return self._torch_tensor
@classmethod
......@@ -67,7 +73,5 @@ class ColoTensor(object):
if kwargs is None:
kwargs = {}
kwargs = {
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return func(*args, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment