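"""Test ColossalAI's ChunkManager: how parameter tensors are grouped into chunks
and how tracked CUDA/CPU memory changes as chunks are appended, accessed,
released, and moved, with and without chunking (``use_chunk``) and ZeRO-style
distributed storage (``use_zero``).
"""
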
from functools import partial
from typing import List

import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini import ChunkManager
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
    for p, has_tensor in zip(params, has_tensors):
        if has_tensor:
            assert p.storage().size() > 0
            assert p.device.type == 'cuda'
        else:
            assert p.storage().size() == 0


# HAS_TENSORS[use_chunk][use_zero][rank]: whether each of the 3 params keeps CUDA storage on that rank after being appended
HAS_TENSORS = {
    True: {
        True: [[True, True, False], [False, False, True]],
        False: [[True, True, True], [True, True, True]]
    },
    False: {
        True: [[True, False, True], [False, True, False]],
        False: [[True, True, True], [True, True, True]]
    }
}

# TOTAL_MEM[use_chunk][use_zero][rank]: expected total_mem['cuda'] in bytes after all params are appended
TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}

@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
    pg = ColoProcessGroup()
    rank = pg.rank()
    if rank == 0:
        print(f'use_chunk={use_chunk}, use_zero={use_zero}')
    params = [torch.rand(8, 8) for _ in range(3)]
    chunk_size = 128 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
    chunk_manager.create_group('param')

    # Nothing has been registered yet, so no memory should be tracked.
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0

    # Appending tensors places them in CUDA chunks; with distributed storage
    # enabled, each rank only keeps its own shard of the chunks.
    for p in params:
        chunk_manager.append_tensor(p, 'param')
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]

    # Accessing a chunk gathers it, so every rank sees all parameters and the
    # CUDA footprint matches the non-ZeRO case.
    chunks = chunk_manager.get_chunks(params)
    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    check_has_params(params, [True, True, True])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]

    # Releasing a chunk drops the gathered copy and falls back to the sharded layout.
    for chunk in chunks:
        chunk_manager.release_chunk(chunk)
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']

    # Moving chunks to CPU transfers the tracked memory from CUDA to CPU.
    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))
    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cpu']
    assert chunk_manager.total_mem['cuda'] == 0


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_chunk_zero()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
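    """Spawn ``world_size`` processes and run the parameterized chunk/ZeRO test on each."""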
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_mapping(2)