# test_tensor.py
import torch
from colossalai.tensor import ColoTensor
from numpy import allclose


def test_tensor_indexing():
    """Check that slicing a ColoTensor agrees with slicing its source tensor.

    A random 2x3 torch tensor is wrapped with
    ``ColoTensor.init_from_torch_tensor``. Column 1 of the wrapper, converted
    back with ``torch_tensor()``, must be numerically close to column 1 of the
    original tensor.
    """
    source = torch.randn(2, 3)
    wrapped = ColoTensor.init_from_torch_tensor(source)

    expected_column = source[:, 1]
    actual_column = wrapped[:, 1].torch_tensor()
    assert allclose(expected_column, actual_column)


def test_lazy_init_tensor():
    """Check that a ColoTensor built from sizes defers allocating storage.

    The tensor is constructed from the sizes 2 and 3 rather than from existing
    data. The test expects the private backing tensor to be empty at first.
    It also expects ``numel()`` to report 6 before ``torch_tensor()`` is
    called, and the tensor returned by ``torch_tensor()`` to hold 6 elements.
    """
    lazy = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)

    # Right after construction the private backing tensor holds no elements.
    assert lazy._torch_tensor.numel() == 0

    # The logical size is checked first, before torch_tensor() is ever called.
    assert lazy.numel() == 6
    # Then the tensor obtained through torch_tensor() must have the same size.
    assert lazy.torch_tensor().numel() == 6