test_tensor_utils.py 3.38 KB
Newer Older
1
2
import pytest

Jiarui Fang's avatar
Jiarui Fang committed
3
import colossalai
4
from colossalai.utils.cuda import get_current_device
5
6
7
8
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move,
                                            colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
                                            colo_model_tensor_clone)
from colossalai.gemini.stateful_tensor import StatefulTensor
9
from colossalai.utils import free_port
10
from colossalai.testing import rerun_if_address_is_in_use
11
12
13
14
15
16

import torch

from functools import partial
import torch.multiprocessing as mp

17

18
19
20
def _run_colo_tensor_mem_usage():
    """Check that ``colo_tensor_mem_usage`` scales with the number of elements.

    The check runs for both a ``StatefulTensor`` wrapper and a bare
    ``torch.Tensor``. Within each pair, the 4x4 tensor has four times as many
    elements as the 2x2 tensor. Both figures returned by
    ``colo_tensor_mem_usage`` must therefore be exactly four times larger.
    """
    # Explicitly list both tensor flavours so the StatefulTensor path is
    # actually exercised (a ``range``-index guard made it easy to skip).
    pairs = [
        (StatefulTensor(torch.randn(2, 2)), StatefulTensor(torch.randn(4, 4))),
        (torch.randn(2, 2), torch.randn(4, 4)),
    ]
    for small, large in pairs:
        c1, g1 = colo_tensor_mem_usage(small)
        c2, g2 = colo_tensor_mem_usage(large)
        assert c1 * 4 == c2
        assert g1 * 4 == g2


36
def _run_colo_model_data_tensor_move_inline():
    """Move tensors in place to the current device and verify where they end up.

    The check covers one ``StatefulTensor`` and one bare 2x3 tensor. After
    ``colo_model_data_tensor_move_inline``, each one must report
    ``get_current_device()`` as its device.
    """
    target = get_current_device()
    candidates = (StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3))
    for tensor in candidates:
        colo_model_data_tensor_move_inline(tensor, target)
        assert tensor.device == target

41
42

def _run_colo_model_data_tensor_move():
    """Call ``colo_model_data_tensor_move`` from a CPU source to a device target.

    The check runs once with ``StatefulTensor`` wrappers and once with bare
    tensors. Afterwards, the destination must still be on
    ``get_current_device()``.
    """
    device = get_current_device()
    stateful_pair = (StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(device)))
    plain_pair = (torch.ones(2, 3), torch.zeros(2, 3).to(device))
    for src, dst in (stateful_pair, plain_pair):
        colo_model_data_tensor_move(src, dst)
        assert dst.device == device

49
50

def _run_colo_model_data_move_to_cpu():
    """Verify that ``colo_model_data_move_to_cpu`` leaves each input on the CPU.

    The check covers one ``StatefulTensor`` (2x2) and one bare tensor (4x4).
    """
    cpu = torch.device("cpu")
    samples = (StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4))
    for sample in samples:
        colo_model_data_move_to_cpu(sample)
        assert sample.device == cpu

55

56
def _run_colo_model_tensor_clone():
    """Check that ``colo_model_tensor_clone`` yields an equal copy on the current device.

    Two sources are cloned to ``get_current_device()``: a ``StatefulTensor``
    wrapping a 2x2 device tensor, and a bare 4x4 device tensor. Each clone
    must live on that device and match its source element-for-element.
    """
    device = get_current_device()
    for t in [StatefulTensor(torch.randn(2, 2).to(device)), torch.randn(4, 4).to(device)]:
        # A StatefulTensor exposes its underlying tensor via ``payload``.
        src = t.payload if isinstance(t, StatefulTensor) else t
        assert src.device == device

        p = colo_model_tensor_clone(t, device)
        assert p.device == device
        assert src.device == p.device
        # Compare the full tensor (shape and every element) rather than a
        # fixed 2x2 corner, so the 4x4 input is fully verified as well.
        assert torch.equal(src, p)

76
77
78

def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then run every tensor-utils check.

    Args:
        rank: index of this process (supplied as the first argument by ``mp.spawn``).
        world_size: total number of spawned processes.
        port: TCP port used with ``host='localhost'`` for the launch.
    """
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    checks = (
        _run_colo_tensor_mem_usage,
        _run_colo_model_data_tensor_move_inline,
        _run_colo_model_data_tensor_move,
        _run_colo_model_data_move_to_cpu,
        _run_colo_model_tensor_clone,
    )
    for check in checks:
        check()
85

86

87
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_zero_tensor_utils(world_size):
    """Spawn ``world_size`` worker processes that each execute ``run_dist``.

    A free port is picked once here and shared by every worker.
    """
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)

94

95
if __name__ == '__main__':
    # Run the smallest parametrization directly, without going through pytest.
    test_zero_tensor_utils(world_size=2)