import torch
import pytest
from functools import partial

import torch.multiprocessing as mp
from numpy import allclose

import colossalai
from colossalai.core import global_context as gpc
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def _run_tensor_indexing():
    # Indexing a ColoTensor should behave exactly like indexing the wrapped
    # torch.Tensor.
    pg = ProcessGroup()
    torch_t = torch.randn(2, 3)
    colo_t = ColoTensor(torch_t, ColoTensorSpec(pg))
    assert allclose(torch_t[:, 1], colo_t[:, 1])


def _run_wrapped_tensor_func():
    # Wrapped torch functions should preserve the ColoTensor type for tensor
    # outputs and pass non-tensor outputs through unchanged.
    pg = ProcessGroup()
    t_ref = torch.randn(4, 5)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))

    # non-callable attribute access
    assert t.is_cuda == t_ref.is_cuda

    # a function returning a single torch.Tensor
    t_abs = t.abs()
    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())

    # a function returning a single non-torch.Tensor value
    assert t.dim() == t_ref.dim()

    # a function returning more than one torch.Tensor
    assert isinstance(t, ColoTensor)
    t_split1, t_split2 = t.split(2)
    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
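

# For reference: the type preservation checked above relies on PyTorch's
# __torch_function__ protocol. A minimal standalone sketch of the same idea in
# stock PyTorch (illustrative only, not ColossalAI's actual implementation):
class _WrappedTensorSketch(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Delegate to the default implementation: Tensor results come back as
        # this subclass, while non-Tensor results (e.g. dim()) pass through.
        return super().__torch_function__(func, types, args, kwargs or {})


# e.g. torch.randn(2).as_subclass(_WrappedTensorSketch).abs() stays a
# _WrappedTensorSketch.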


def _run_operand():
    # Arithmetic on a ColoTensor should match arithmetic on the wrapped
    # torch.Tensor.
    pg = ProcessGroup()
    t_ref = torch.randn(4, 5)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))

    t_ref_res = t_ref + t_ref
    t_res = t + t
    assert torch.allclose(t_ref_res, t_res)


#### Test distributed initialization of a ColoTensor


def _run_view(world_size):
    # Each rank holds a (4, 5) chunk declared as one dim-0 shard of a global
    # (4 * world_size, 5) tensor; size_global() and view_global() should
    # operate on the global shape.
    t_ref = torch.randn(4, 5)
    rank = gpc.get_global_rank()
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    t = ColoTensor.from_torch_tensor(
        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))

    assert t.size_global()[0] == 4 * world_size
    assert t.size_global(1) == 5
    assert t.size_global() == torch.Size([4 * world_size, 5])

    t = t.view_global(4 * 5 * world_size)
    assert t.shape == torch.Size([4 * 5 * world_size])
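

# For intuition, a single-process plain-torch sketch of the shard/global
# relationship tested above (`world_size` and `rank` are ordinary arguments
# here, not ColossalAI state):
def _shard_sketch(world_size: int, rank: int):
    full = torch.randn(4 * world_size, 5)
    # shard(dims=[0], num_partitions=[world_size]): each rank keeps one
    # contiguous row-chunk of the global tensor.
    local = torch.chunk(full, world_size, dim=0)[rank]    # shape (4, 5)
    # view_global reasons about the concatenation of all chunks, i.e. `full`.
    global_view = full.view(4 * 5 * world_size)
    return local, global_view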


def _run_tensor_shard_init(world_size):
    # Initialize a ColoTensor from a local (4, 5) shard, convert it to a
    # replicated layout, and check that the global shape is recovered.
    t_ref = torch.randn(4, 5)

    pg = ProcessGroup(tp_degree=world_size)
    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
    tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
    t.set_dist_spec(distspec.replicate())
    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
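

# Conceptually, the shard -> replicate conversion above amounts to an
# all-gather along the shard dim. A plain-torch, single-process sketch
# (illustrative only; `shards` stands in for the per-rank (4, 5) chunks):
def _replicate_sketch(shards):
    # Concatenating every rank's chunk along dim 0 restores the global
    # (4 * world_size, 5) tensor whose shape the test asserts.
    return torch.cat(shards, dim=0)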


def _run_tensor_replicated_init(world_size):
    # A replicated ColoTensor keeps the full global shape on every rank.
    t_ref = torch.randn(4 * world_size, 5)
    pg = ProcessGroup()
    spec = ColoTensorSpec(pg)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), spec)

    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"


def _run_process_group(world_size):
    # Two ProcessGroups created with identical (default) arguments should
    # compare equal.
    pg1 = ProcessGroup()
    pg2 = ProcessGroup()

    assert pg1 == pg2


def run_dist_tests(rank, world_size, port):
    # Entry point for each spawned process: initialize the distributed
    # environment, then run every case in-process.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_tensor_shard_init(world_size)
    _run_tensor_replicated_init(world_size)
    _run_view(world_size)
    _run_process_group(world_size)
    _run_tensor_indexing()
    # TODO: not passing yet
    # _run_wrapped_tensor_func()
    _run_operand()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_dist_cases(world_size):
    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_dist_cases(2)