Unverified Commit 2ecc3d7a authored by Jiarui Fang, committed by GitHub

[tensor] lazy init (#823)

parent 68dcd51d
 import torch
 from .op_wrapper import _COLOSSAL_OPS
+from typing import Tuple


 class ColoTensor(object):
+    """ Data Structure for a Tensor in Colossal-AI.
+    1. It contains a torch.Tensor as an attribute.
+    2. It supports lazy initialization of the tensor's payload.
+    3. It can hijack torch functions that take ColoTensors as args and dispatch them to our customized functions.
+    4. It supports distributing the tensor's payload to shards among processes. (TODO)
+    """

     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(self, t: torch.Tensor) -> None:
-        self._torch_tensor = t
+    def __init__(
+            self,
+            *size: Tuple[int],
+            dtype=None,
+            requires_grad=False,
+            pin_memory=False,
+            torch_tensor=None,
+    ):
+        # Store only the metadata; the payload is allocated lazily in torch_tensor().
+        self._size = size
+        self._dtype = dtype
+        self._requires_grad = requires_grad
+        self._pin_memory = pin_memory
+        self._torch_tensor = torch_tensor
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor):
+        # Wrap an existing tensor: metadata is copied and the payload is set eagerly.
+        colo_t = ColoTensor(*tensor.size(),
+                            dtype=tensor.dtype,
+                            requires_grad=tensor.requires_grad,
+                            pin_memory=tensor.is_pinned(),
+                            torch_tensor=tensor)
+        return colo_t

     def torch_tensor(self) -> torch.Tensor:
+        # Materialize the payload on first access if it was lazily initialized.
+        if self._torch_tensor is None:
+            self._torch_tensor = torch.empty(*self._size,
+                                             dtype=self._dtype,
+                                             requires_grad=self._requires_grad,
+                                             pin_memory=self._pin_memory)
         return self._torch_tensor

     @classmethod
     ......
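As a quick illustration of the lazy-init flow introduced above, a ColoTensor can now be built from metadata alone and its payload is only allocated on the first call to torch_tensor(). The snippet below is a minimal sketch assuming only the constructor and accessor shown in this diff; the variable names are illustrative, not part of the change.

import torch
from colossalai.tensor import ColoTensor

# Construct with metadata only: no torch.Tensor payload is allocated yet.
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None  # payload not materialized

# First access lazily allocates the payload via torch.empty.
payload = lazy_t.torch_tensor()
assert payload.shape == (2, 3) and payload.numel() == 6

# Wrapping an existing tensor skips the lazy path: the payload is stored directly.
eager_t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))
assert eager_t.torch_tensor().numel() == 16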
 from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -48,7 +48,7 @@ def test_linear():

 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
@@ -57,10 +57,16 @@ def test_element_wise():

 # Test a function not wrapped by _COLOSSAL_OPS
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)

+
+def test_lazy_init_tensor():
+    # Construct with metadata only; the payload must not exist until torch_tensor() is called.
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor is None
+    assert lazy_t.torch_tensor().numel() == 6
+

 if __name__ == '__main__':
-    test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()