# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time

import torch
from tqdm import tqdm

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe


def test_run(my_rank, pipe):
    """Smoke-test one round trip of two small tensors through ``pipe``.

    Rank 0 sends ``x`` and then ``y`` and afterwards receives two tensors
    back; any other rank receives first and then sends. Every rank builds
    the same ``x`` and ``y`` locally, so each side can compare what it
    received against its own copies.

    Args:
        my_rank: Rank of this process. 0 sends first; any other value
            receives first.
        pipe: Pipe under test. Must expose ``device``, ``send_tensor`` and
            ``recv_tensor``.

    Raises:
        AssertionError: If a received tensor differs from the expected one
            in shape, dtype or values.
    """
    print(f"rank {my_rank} test_run starts....")

    # test run
    x = torch.tensor([1]).to(pipe.device)
    y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device)
    if my_rank == 0:
        pipe.send_tensor(x)
        print(f"rank {my_rank} sent tensor x")
        pipe.send_tensor(y)
        print(f"rank {my_rank} sent tensor y")
        x2 = pipe.recv_tensor()
        print(f"rank {my_rank} received x2 = ", x2)
        y2 = pipe.recv_tensor()
        print(f"rank {my_rank} received y2 = ", y2)

    else:
        x2 = pipe.recv_tensor()
        print(f"rank {my_rank} received x2 = ", x2)
        y2 = pipe.recv_tensor()
        print(f"rank {my_rank} received y2 = ", y2)
        pipe.send_tensor(x)
        print(f"rank {my_rank} sent tensor x")
        pipe.send_tensor(y)
        print(f"rank {my_rank} sent tensor y")

    for name, expected, received in (("x", x, x2), ("y", y, y2)):
        # torch.allclose broadcasts its operands, so on its own it would
        # accept e.g. a (4,) tensor in place of the (1, 4) one that was
        # sent. Check the metadata explicitly before comparing values.
        assert received.shape == expected.shape, (
            f"{name}: expected shape {tuple(expected.shape)}, "
            f"got {tuple(received.shape)}")
        assert received.dtype == expected.dtype, (
            f"{name}: expected dtype {expected.dtype}, got {received.dtype}")
        assert torch.allclose(expected, received), f"{name}: values differ"

    print(f"rank {my_rank} test_run passed!")


def stress_test(my_rank, pipe, num_iters=500):
    """Exchange many randomly sized tensors (and occasional ``None``s).

    Both ranks seed the RNG identically and pre-generate the same list of
    payloads, three entries per iteration: a random 2-D tensor, its mean
    and its standard deviation, or three ``None``s with 5% probability.
    During the exchange phase the sender alternates in blocks: rank 0 sends
    when ``i % 10`` is 0..3 and rank 1 sends when it is 4..9. The receiving
    rank validates what arrives against its own pre-generated copy.

    Args:
        my_rank: Rank of this process (0 or 1).
        pipe: Pipe under test. Must expose ``device``, ``send_tensor`` and
            ``recv_tensor``.
        num_iters: Number of payload triples to generate and exchange.

    Raises:
        AssertionError: If a received triple does not match the expected
            one.
    """
    print(f"rank {my_rank} stress_test starts....")

    # Entries are None for the "send a None" iterations.
    tensors: list[torch.Tensor | None] = []

    torch.distributed.barrier()
    # Same seed on every rank so that all ranks generate identical payloads.
    torch.manual_seed(0)

    for _ in tqdm(range(num_iters)):
        # NOTE: the order of RNG draws below defines the payload sequence;
        # keep it identical on all ranks.
        mean = torch.rand(1).item() * 100
        std = torch.rand(1).item() * 100
        size = torch.randint(900, 1000, (2, ))
        x = torch.normal(mean * 1.0, std * 1.0,
                         size=size.tolist()).to(pipe.device)

        # 5% probability of sending a None
        if torch.rand(1).item() < 0.05:
            tensors.extend([None, None, None])
        else:
            tensors.extend([x, x.mean().unsqueeze(0), x.std().unsqueeze(0)])

    torch.distributed.barrier()

    for i in tqdm(range(num_iters)):
        expected_x, expected_mean, expected_std = tensors[3 * i:3 * i + 3]

        if my_rank == int((i % 10) > 3):
            pipe.send_tensor(expected_x)
            pipe.send_tensor(expected_mean)
            pipe.send_tensor(expected_std)
        else:
            x = pipe.recv_tensor()
            mean = pipe.recv_tensor()
            std = pipe.recv_tensor()

            if x is None:
                # A None must only arrive where a None was actually sent;
                # otherwise a silently dropped payload would go unnoticed.
                assert expected_x is None
                assert mean is None
                assert std is None
            else:
                assert expected_x is not None
                # allclose broadcasts, so pin the shape explicitly.
                assert x.shape == expected_x.shape
                assert torch.allclose(x, expected_x)
                # Tolerant comparison: the statistics were reduced on the
                # sender, and recomputing them locally is not guaranteed
                # to be bit-identical.
                assert torch.allclose(x.mean(), mean[0])
                assert torch.allclose(x.std(), std[0])

        # Keep both ranks in lockstep before the next exchange.
        torch.distributed.barrier()


def latency_test(my_rank, pipe, nelement, ntensor, num_iters=500):
    """Measure how long it takes to move a batch of tensors from rank 0.

    In every iteration rank 0 creates ``ntensor`` random 1-D tensors of
    ``nelement`` elements, records the wall-clock time, sends the tensors
    and finally sends the timestamp itself. The receiving rank drains the
    tensors, reads the timestamp and records ``now - timestamp``.

    The measurement compares ``time.time()`` readings taken in two
    different processes, so it is only meaningful when both ranks share a
    clock (e.g. run on the same host).

    Args:
        my_rank: Rank of this process; rank 0 sends, any other rank
            receives and measures.
        pipe: Pipe under test. Must expose ``device``, ``send_tensor`` and
            ``recv_tensor``.
        nelement: Number of elements in each tensor.
        ntensor: Number of tensors sent per iteration.
        num_iters: Number of timed iterations.
    """
    latencies = []

    torch.distributed.barrier()

    for _ in tqdm(range(num_iters)):

        tensors = []

        if my_rank == 0:
            # create tensor
            tensors = [
                torch.rand(nelement).to(pipe.device) for _ in range(ntensor)
            ]

        # Start the timed section only once the payload exists on rank 0.
        torch.distributed.barrier()

        if my_rank == 0:
            # float64 keeps sub-millisecond resolution of the epoch time.
            t = torch.tensor([time.time()],
                             dtype=torch.float64).to(pipe.device)
            for tensor in tensors:
                pipe.send_tensor(tensor)
            pipe.send_tensor(t)
        else:
            for _ in range(ntensor):
                pipe.recv_tensor()
            t = pipe.recv_tensor()
            latencies.append(time.time() - t.item())

    torch.distributed.barrier()

    print('Latency test passed.')
    # Only the receiving rank collects samples; the mean of an empty tensor
    # is NaN, so the sender must not report a latency figure.
    if latencies:
        print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms')


if __name__ == "__main__":
    # Entry point for one of the two ranks. The launcher must export RANK
    # (0 or 1) for each process.

    my_rank = int(os.environ['RANK'])

    # The gloo group is only used for the barriers inside the tests; the
    # tensors themselves travel through the pipe.
    torch.distributed.init_process_group(
        backend='gloo',
        init_method='tcp://localhost:12398',
        world_size=2,
        rank=my_rank,
    )

    try:
        config = KVTransferConfig(
            kv_connector='P2pNcclConnector',
            kv_buffer_device='cuda',
            kv_buffer_size=1e9,
            kv_rank=my_rank,
            kv_role="kv_both",  # this arg doesn't matter in this test
            kv_parallel_size=2,
            kv_ip="127.0.0.1",
            kv_port=12345,
        )

        pipe = PyNcclPipe(
            local_rank=my_rank,
            config=config,
        )

        test_run(my_rank, pipe)

        stress_test(my_rank, pipe)

        # Use this function if you want to test the latency of pipe impl.
        # latency_test(my_rank, pipe, 1024 * 8 * 128, 80)
    finally:
        # Tear the process group down even when a test assertion fails, so
        # the process exits cleanly.
        torch.distributed.destroy_process_group()