#!/usr/bin/env python
# -*- encoding: utf-8 -*-
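# Tests for ShardedTensor and ShardedParam from colossalai.zero.sharded_param:
# shard() flattens a tensor into a per-rank 1-D slice, gather() restores the
# original shape, and a ShardedParam can be built from an nn.Parameter or a
# shape tuple on either CPU or CUDA.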

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG


def run_shard_tensor(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))

    assert list(t.shape) == [world_size * 2, 3]
    t.shard()
    # shard() flattens the payload to 1-D; numel is world_size * 2 * 3, so each rank holds 6 elements
    assert list(t.shape) == [6]
    # calling shard() again on an already-sharded tensor is a no-op
    t.shard()
    assert list(t.shape) == [6]

    t.gather()
    assert list(t.shape) == [world_size * 2, 3]

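    # assigning to t.payload replaces the stored data; an all-zero payload should sum to 0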
    t.payload = torch.zeros(world_size * 2, 3)
    assert torch.sum(t.payload).cpu() == 0


@pytest.mark.dist
def test_shard_tensor():
    world_size = 2
    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


def run_init_shard_param(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
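    # ShardedParam appears to accept either an nn.Parameter or a shape tuple as its
    # first argument, plus process_group, is_sharded and device (inferred from the
    # keyword calls below); with world_size == 2, a (2, 3) parameter shards to
    # 3 elements per rank.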
    param = torch.nn.Parameter(data=torch.rand(2, 3))
    sparam = ShardedParam(param, None, True)
    payload = sparam.payload(torch.device('cuda'))
    assert list(payload.shape) == [3]
    del sparam

    param_shape = (2, 3)
    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
    payload = sparam.payload(torch.device('cuda'))
    assert list(payload.shape) == [3]

    param_shape = (2, 3)
    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
    payload = sparam.payload(torch.device('cuda'))
    assert list(payload.shape) == [2, 3]


def run_shard_param_check(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    logger = get_dist_logger()
    model = Net()

    # attach a ShardedParam as `ca_attr` to each parameter so the sharded payload
    # can be used in place of param.data
    for _, param in model.named_parameters():
        numel_ref = (param.numel() + world_size - 1) // world_size
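        # after shard(), this rank should hold ceil(numel / world_size) elements, i.e. numel_ref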
        param.ca_attr = ShardedParam(param)
        param.ca_attr.shard()
        param_data = param.ca_attr.payload(torch.device('cpu'))
        assert numel_ref == param_data.numel()

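    # gather the shards back; payload() should now return the full, unsharded parameter data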
    for _, param in model.named_parameters():
        param.ca_attr.gather()
        param_data = param.ca_attr.payload(torch.device('cpu'))

    disable_existing_loggers([logger])


@pytest.mark.dist
def test_shard_shape():
    world_size = 2
    run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
def test_init_shard_param():
    world_size = 2
    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_tensor()
    test_shard_shape()
    test_init_shard_param()