import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.tensor import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.chunk import Chunk


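# Sum a per-rank scalar across all ranks; used below to check that the shard
# valid regions add up to the chunk's utilized size.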
def dist_sum(x):
    temp = torch.tensor([x], device=get_accelerator().get_current_device())
    dist.all_reduce(temp)
    return temp.item()


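# Append a fresh random ColoParameter together with an independent clone, so
# the test can later verify that the chunk did not corrupt the payload.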
def add_param(param_list, param_cp_list, *args, **kwargs):
    param = ColoParameter(torch.randn(*args, **kwargs))
    param_list.append(param)
    param_cp_list.append(param.clone())


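# Compare a parameter against its saved copy, moving the data to a common
# device first so torch.equal operates on same-device tensors.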
def check_equal(param, param_cp):
    if param.device != param_cp.device:
        temp = param.data.to(param_cp.device)
    else:
        temp = param.data
    return torch.equal(temp, param_cp.data)


@parameterize("init_device", [None, torch.device("cpu")])
@parameterize("keep_gathered", [True, False])
@parameterize("pin_memory", [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
    world_size = torch.distributed.get_world_size()
    pg = _get_default_group()
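    # A 1024-element fp32 chunk over the default process group; the shard is
    # initialized on CPU (cpu_shard_init=True).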
    my_chunk = Chunk(
        chunk_size=1024,
        zero_group=pg,
        dtype=torch.float32,
        init_device=init_device,
        cpu_shard_init=True,
        keep_gathered=keep_gathered,
        pin_memory=pin_memory,
    )

    param_list = []
    param_cp_list = []

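    # Four tensors totaling 8*8*8 + 4*4 + 4*8*2 + 1*1*5 = 597 elements, drawn
    # from a mix of CUDA and CPU sources.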
    add_param(param_list, param_cp_list, 8, 8, 8, device="cuda")
    add_param(param_list, param_cp_list, 4, 4)
    add_param(param_list, param_cp_list, 4, 8, 2, device="cuda")
    add_param(param_list, param_cp_list, 1, 1, 5)

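    # Register every parameter with the chunk and confirm the payloads survive.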
    for param in param_list:
        my_chunk.append_tensor(param)
    assert my_chunk.utilized_size == 597
    for param, param_cp in zip(param_list, param_cp_list):
        assert check_equal(param, param_cp)
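    # Seal the chunk: no more tensors can be appended after this point.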
    my_chunk.close_chunk()

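    # A closed chunk is sharded (here onto CPU, per cpu_shard_init=True) unless
    # keep_gathered pins the whole chunk on CUDA, which also makes it immovable.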
    if keep_gathered is False:
        assert my_chunk.cpu_shard.size(0) == 1024 // world_size
        assert my_chunk.device_type == "cpu"
        assert my_chunk.can_move
        my_chunk.shard_move(get_accelerator().get_current_device())
    else:
        assert my_chunk.cuda_global_chunk.size(0) == 1024
        assert my_chunk.device_type == "cuda"
        assert not my_chunk.can_move

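    # The valid regions of all shards must sum to the chunk's utilized size.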
    assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
    flag = my_chunk.has_inf_or_nan
    assert not flag, "has_inf_or_nan is {}".format(flag)

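    # Gather the full chunk onto the accelerator and re-check its contents.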
    my_chunk.access_chunk()
    assert my_chunk.device_type == "cuda"
    for param, param_cp in zip(param_list, param_cp_list):
        assert check_equal(param, param_cp)

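    # All four tensors start in HOLD; moving one to COMPUTE must update the
    # state counters and prevent the chunk from being released.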
    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
    my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3
    assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
    assert not my_chunk.can_release

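    # Drive each tensor through the COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE
    # lifecycle, as a forward/backward pass would.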
    for param in param_list:
        my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
        my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD)
        my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)

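    # Only a chunk whose tensors are all READY_FOR_REDUCE can be reduced;
    # reduction returns every tensor to HOLD.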
    assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
    assert my_chunk.can_reduce
    my_chunk.reduce()
    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4

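    # After reduction the shard (if any) lives on CUDA; a gathered chunk is
    # still immovable.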
    if keep_gathered is False:
        assert my_chunk.cuda_shard.size(0) == 1024 // world_size
        assert my_chunk.device_type == "cuda"
        assert my_chunk.can_move
    else:
        assert my_chunk.cuda_global_chunk.size(0) == 1024
        assert my_chunk.device_type == "cuda"
        assert not my_chunk.can_move


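# Per-process entry point: initialize the distributed runtime over NCCL, then
# run every parameterized combination of exam_chunk_basic.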
def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_chunk_basic()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_chunk_function(4)