# test_p2p_communication.py
import pytest
import torch
import torch.distributed as dist

import colossalai
6
from colossalai.accelerator import get_accelerator
7
from colossalai.cluster import ProcessGroupMesh
8
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
9
10
11
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

12
13
# Passed to ``spawn`` by the test entry point and to ``ProcessGroupMesh`` by
# the check; the check itself only has branches for rank 0 and rank 1.
WORLD_SIZE = 2

14
15

def check_p2p_communication():
    """Exercise ``PipelineP2PCommunication`` between rank 0 and rank 1.

    ``run_dist`` calls this after ``colossalai.launch``. Both ranks build the
    same list of payloads (a string, a tensor, a list and a dict holding that
    tensor) and then run mirrored send/receive sequences; every received
    object is asserted equal to the payload the peer sent.

    The order of the communication calls on each rank is significant: each
    send on one rank is paired with a receive on the other, so statements must
    not be reordered.
    """
    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager)

    rank = dist.get_rank()

    # A one-element tensor, so the result of ``==`` between tensors can be
    # used directly as the truth value of an ``assert``.
    tensor = torch.ones(1, device=get_accelerator().get_current_device())
    data = [
        "tensor",
        tensor,
        [tensor],
        {"tensor": tensor},
    ]

    # Phase 1: rank 0 drives. It first sends every payload forward, then does
    # combined send-forward/receive-backward calls while rank 1 answers with
    # the payloads in reverse order.
    if rank == 0:
        for obj in data:
            p2p.send_forward(obj)
        for outgoing, expected in zip(data, reversed(data)):
            # NOTE(review): semantics of ``send_prior_fallback`` are defined
            # in colossalai.pipeline.p2p — confirm there.
            recv_obj = p2p.send_forward_recv_backward(outgoing, send_prior_fallback=False)
            assert recv_obj == expected
    elif rank == 1:
        for obj in data:
            recv_obj = p2p.recv_forward()
            assert recv_obj == obj
        for expected, outgoing in zip(data, reversed(data)):
            p2p.send_backward(outgoing)
            recv_obj = p2p.recv_forward()
            assert recv_obj == expected

    # Phase 2: the mirror image of phase 1, with rank 1 driving in the
    # backward direction and rank 0 answering.
    if rank == 1:
        for obj in data:
            p2p.send_backward(obj)
        for outgoing, expected in zip(data, reversed(data)):
            recv_obj = p2p.send_backward_recv_forward(outgoing, send_prior_fallback=True)
            assert recv_obj == expected
    elif rank == 0:
        for obj in data:
            recv_obj = p2p.recv_backward()
            assert recv_obj == obj
        for expected, outgoing in zip(data, reversed(data)):
            recv_obj = p2p.recv_backward()
            p2p.send_forward(outgoing)
            assert recv_obj == expected

    # Phase 3: a single tensor round trip where the sender passes
    # ``send_metadata=False`` and the receiver supplies ``metadata_recv``
    # built locally with ``create_send_metadata``.
    if rank == 0:
        recv_obj = p2p.send_forward_recv_backward(
            tensor,
            send_metadata=False,
            metadata_recv=create_send_metadata(tensor),
        )
        assert recv_obj == tensor
    elif rank == 1:
        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
        assert recv_obj == tensor
        p2p.send_backward(tensor, send_metadata=False)
71
72
73


def run_dist(rank, world_size, port):
    """Worker entry point: launch colossalai, then run the P2P check.

    Args:
        rank: forwarded to ``colossalai.launch`` as ``rank``.
        world_size: forwarded to ``colossalai.launch`` as ``world_size``.
        port: forwarded to ``colossalai.launch`` as ``port``; the host is
            always ``"localhost"``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_p2p_communication()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
    """Run the pipeline P2P check by handing ``run_dist`` and ``WORLD_SIZE`` to ``spawn``."""
    spawn(run_dist, WORLD_SIZE)
82
83


84
if __name__ == "__main__":
    # Allow running this test directly as a script, without going through pytest.
    test_pipeline_p2p()