test_2d.py 1.28 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
Frank Lee's avatar
Frank Lee committed
5
6
import torch
import torch.multiprocessing as mp
zbian's avatar
zbian committed
7
8

from colossalai.core import global_context as gpc
Frank Lee's avatar
Frank Lee committed
9
from colossalai.initialize import launch, get_default_parser
Frank Lee's avatar
Frank Lee committed
10
11
12
13
from checks_2d.check_layer_2d import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
from functools import partial

zbian's avatar
zbian committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# Parallelism layout for the tests: no pipeline stages, 4-way tensor
# parallelism arranged as a 2D (2x2) process mesh.
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')))


def check_operations():
    """Run every 2D-parallel matrix-operation check (AB, AB^T, A^T B)."""
    for op_check in (check_AB, check_ABT, check_ATB):
        op_check()


def check_layer():
    """Run every 2D-parallel layer check: linear, layernorm, attention,
    MLP, and the full transformer layer."""
    for layer_check in (check_linear, check_layernorm, check_attention,
                        check_mlp, check_transformerlayer):
        layer_check()


Frank Lee's avatar
Frank Lee committed
40
def check_layer_and_operation(rank, world_size, port=29921):
    """Worker entry point for one spawned process: set up the distributed
    environment, run all 2D-parallel operation and layer checks, then tear
    everything down.

    Args:
        rank (int): rank of this worker process (supplied by ``mp.spawn``).
        world_size (int): total number of worker processes.
        port (int): TCP port for the distributed rendezvous. Parameterized
            (default keeps the original 29921) so concurrent test runs can
            pick distinct ports instead of colliding on a hard-coded one.
    """
    launch(config=CONFIG,
           rank=rank,
           world_size=world_size,
           host='localhost',
           port=port,
           backend='nccl')

    check_operations()
    check_layer()
    gpc.destroy()
    # Release cached GPU memory so subsequent tests in the same session
    # start from a clean allocator state.
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_2d():
    """Spawn four worker processes and run the full 2D-parallelism suite."""
    world_size = 4
    mp.spawn(partial(check_layer_and_operation, world_size=world_size),
             nprocs=world_size)


# Allow running this test file directly (without invoking pytest).
if __name__ == '__main__':
    test_2d()