# test_p2p_communication.py
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


def check_p2p_communication():
    """Round-trip a tensor, a list and a dict through pipeline P2P calls.

    Rank 0 sends the three payloads with ``send_forward`` and every other
    rank receives and checks them with ``recv_forward``. Then rank 1 sends
    the same payloads with ``send_backward`` and every other rank receives
    and checks them with ``recv_backward``.

    NOTE(review): calls ``dist.get_rank()``, so this presumably has to run
    after the distributed backend has been initialised -- confirm.
    """
    pg_mesh = ProcessGroupMesh(2)
    # NOTE(review): the second argument is presumably the mesh axis used for
    # the pipeline stages -- confirm against PipelineStageManager.
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager)

    rank = dist.get_rank()

    tensor = torch.ones(1, device=get_current_device())

    # Forward direction: rank 0 is the sender, the other rank is the receiver.
    if rank == 0:
        p2p.send_forward(tensor)
        p2p.send_forward([tensor])
        p2p.send_forward({"tensor": tensor})
    else:
        obj = p2p.recv_forward()
        assert torch.equal(obj, tensor)
        obj = p2p.recv_forward()
        assert isinstance(obj, list) and len(obj) == 1 and torch.equal(obj[0], tensor)
        obj = p2p.recv_forward()
        assert isinstance(obj, dict) and "tensor" in obj and torch.equal(obj["tensor"], tensor)

    # Backward direction: rank 1 is the sender, the other rank is the receiver.
    if rank == 1:
        p2p.send_backward(tensor)
        p2p.send_backward([tensor])
        p2p.send_backward({"tensor": tensor})
    else:
        obj = p2p.recv_backward()
        assert torch.equal(obj, tensor)
        obj = p2p.recv_backward()
        assert isinstance(obj, list) and len(obj) == 1 and torch.equal(obj[0], tensor)
        obj = p2p.recv_backward()
        assert isinstance(obj, dict) and "tensor" in obj and torch.equal(obj["tensor"], tensor)


def run_dist(rank, world_size, port):
    """Launch colossalai for one process, then run the P2P check.

    Args:
        rank: Passed through to ``colossalai.launch`` as ``rank``.
        world_size: Passed through to ``colossalai.launch`` as ``world_size``.
        port: Passed through to ``colossalai.launch`` as ``port``; the host
            is fixed to ``"localhost"``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_p2p_communication()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
    """Hand ``run_dist`` to ``spawn`` with a count of two.

    NOTE(review): the count is presumably the number of worker processes,
    matching the two ranks the P2P check expects -- confirm against ``spawn``.
    """
    num_workers = 2
    spawn(run_dist, num_workers)


if __name__ == "__main__":
    # Allow running the test directly as a script, without the pytest runner.
    test_pipeline_p2p()