test_shard_param.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
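"""Tests for colossalai.zero.sharded_param.ShardedParam.

run_init_shard_param checks the payload shape of a ShardedParam built
from an existing parameter or from a raw shape, both sharded and
unsharded. run_shard_param_check shards and gathers every parameter of
a small model and verifies the payload sizes.
"""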

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.sharded_param import ShardedParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG


def run_init_shard_param(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
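    # A (2, 3) parameter has 6 elements; with world_size == 2 each rank
    # should end up holding a flattened shard of 3 elements.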
    param = torch.nn.Parameter(data=torch.rand(2, 3))
    sparam = ShardedParam(param, process_group=None, is_sharded=True)
    payload = sparam.payload(torch.device('cuda'))
    assert list(payload.shape) == [3]
    del sparam

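    # Build from a raw shape instead of a live parameter; the payload is
    # still flattened and sharded across ranks.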
    param_shape = (2, 3)
    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
    payload = sparam.payload(torch.device('cuda'))
    assert list(payload.shape) == [3]

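    # With is_sharded=False the payload keeps the full (2, 3) shape.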
    param_shape = (2, 3)
    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
    payload = sparam.payload(torch.device('cuda'))
    assert list(payload.shape) == [2, 3]


def run_shard_param_check(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    logger = get_dist_logger()
    model = Net()

    # Attach a ShardedParam to each parameter as ca_attr to hijack access to param.data.
    for _, param in model.named_parameters():
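        # Each rank's shard holds ceil(numel / world_size) elements of the
        # flattened parameter.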
        numel_ref = (param.numel() + world_size - 1) // world_size
        param.ca_attr = ShardedParam(param)
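        # Shard in place, then fetch this rank's local shard on the CPU.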
        param.ca_attr.shard()
        param_data = param.ca_attr.payload(torch.device('cpu'))
        assert numel_ref == param_data.numel()

    for _, param in model.named_parameters():
        param.ca_attr.gather()
        param_data = param.ca_attr.payload(torch.device('cpu'))
        # After gather() the payload should be back to the full parameter
        # size (assuming shard()/gather() leave param.data itself untouched).
        assert param_data.numel() == param.numel()

    disable_existing_loggers([logger])


@pytest.mark.dist
def test_shard_shape():
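    # Spawn one worker process per rank; free_port() picks an unused port
    # for the distributed rendezvous.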
    world_size = 2
    run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
def test_init_shard_param():
    world_size = 2
    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_shape()
    test_init_shard_param()