# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import random

import torch
from tqdm import tqdm

from vllm.config import KVTransferConfig
11
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer
12
13
14
15
16
17
18
19
20
21
22
23
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe

# TODO: this test depends on many fields of the current implementation.
# We should use a standard interface instead of direct field access.


def test_run(my_rank, buffer, device):
    """Smoke-test one insert / drop_select round trip between two ranks.

    Rank 0 inserts a single entry, rank 1 retrieves it with ``drop_select``
    and checks the returned tensors, and rank 0 finally verifies that the
    buffer is empty again.

    Args:
        my_rank: Rank of this process (0 inserts, 1 selects).
        buffer: Lookup buffer under test; must expose ``insert``,
            ``drop_select``, ``buffer`` and ``buffer_size``.
        device: Device on which the test tensors are created.

    Raises:
        AssertionError: If any of the buffer-state or payload checks fail.
    """
    # buffer should be empty in the beginning
    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, device: {device}")

    # insert
    tokens = torch.tensor([1, 2, 3]).to(device)
    roi = tokens > 0
    if my_rank == 0:
        key = 2.0 * torch.ones([5, 6]).to(device)
        value = 3.0 * torch.ones([5, 6]).to(device)

        placeholder = torch.tensor([1]).to(device)

        buffer.insert(tokens, roi, key, value, placeholder)

    # Both ranks wait here so the select below happens after the insert call.
    torch.distributed.barrier()

    # drop_select
    if my_rank == 1:
        # The fifth returned tensor is not checked by this test.
        tok, roi_, key, value, _hidden = buffer.drop_select(tokens, roi)
        assert torch.allclose(tokens, tok)
        assert torch.allclose(roi, roi_)
        assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
        assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
    torch.distributed.barrier()

    # After the entry has been selected the sender's buffer must be empty.
    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, Test run passed!")


def stress_test(my_rank, buf, device):
    """Push 200 random requests through the buffer in mismatched order.

    Both ranks build the same 200 requests (same torch seed) and then shuffle
    them with a rank-specific seed. Rank 0 inserts every request; rank 1 calls
    ``drop_select`` for every request and counts how many lookups returned
    ``None``. At the end rank 0 checks that this count equals the number of
    entries still left in its buffer.

    Args:
        my_rank: Rank of this process (0 inserts, any other rank selects).
        buf: Lookup buffer under test; must expose ``insert``,
            ``drop_select``, ``buffer`` and ``buffer_size``.
        device: Device on which the request tensors are created.

    Raises:
        AssertionError: If a retrieved request does not match what was
            requested, or the final buffer accounting is inconsistent.
    """
    torch.distributed.barrier()
    # Same torch seed on every rank so all ranks generate identical requests.
    torch.manual_seed(100)

    reqs = [
        (
            torch.rand(100).to(device),  # tokens
            torch.ones(100).bool().to(device),  # roi
            torch.rand(100).to(device),  # key
            torch.rand(100).to(device),  # value
            torch.rand(100).to(device),  # hidden
        )
        for _ in tqdm(range(200))
    ]

    # Rank-specific shuffle: sender and receiver walk the requests in
    # different orders.
    random.seed(my_rank)
    random.shuffle(reqs)

    torch.distributed.barrier()

    # number of drop_select calls on this rank that returned None
    n = 0

    # the buffer size can only store 100 reqs
    # so the sender will occasionally block to wait for the receiver.
    for req in tqdm(reqs):
        if my_rank == 0:
            buf.insert(*req)
        else:
            tok, roi, k, v, h = req
            tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)

            if tok_ is None:
                # A miss must be reported consistently in every field.
                assert roi_ is None
                assert k_ is None
                assert v_ is None
                assert h_ is None
                n += 1
            else:
                assert torch.allclose(tok, tok_)
                assert torch.allclose(roi, roi_)
                assert torch.allclose(k, k_)
                assert torch.allclose(v, v_)
                assert torch.allclose(h, h_)
    print(f"Rank {my_rank} done")
    torch.distributed.barrier()

    if my_rank == 0:
        x = torch.tensor([0])
        torch.distributed.recv(x, 1)
        # the # of None received is the kv that are not selected
        assert x.item() == len(buf.buffer)
        # and the size of the buffer should be 1700 * buffer len
        # (presumably bytes per request: four 100-element float32 tensors at
        # 400 bytes each plus one 100-element bool roi at 100 bytes --
        # TODO confirm against the buffer's size accounting)
        print(buf.buffer_size)
        assert buf.buffer_size == 1700 * len(buf.buffer)
    else:
        torch.distributed.send(torch.tensor([n]), 0)

    print(f"My rank: {my_rank}, Passed stress test!")


if __name__ == "__main__":
    # Script entry point: expects to be launched once per rank with the RANK
    # environment variable set (0 or 1, since world_size is 2).
    my_rank = int(os.environ["RANK"])

    # Process group used by the tests for barrier/send/recv.
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12398",
        world_size=2,
        rank=my_rank,
    )

    print(f"initialized! My rank is {my_rank}")

    config = KVTransferConfig(
        kv_connector="P2pNcclConnector",
        kv_buffer_device="cuda",
        kv_buffer_size=1e9,
        kv_rank=my_rank,
        kv_role="kv_both",  # this arg doesn't matter in this test
        kv_parallel_size=2,
        kv_ip="127.0.0.1",
        kv_port=12345,
    )

    # Two pipes built from the same config, one per device; they use
    # different port offsets.
    data_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cuda",
        port_offset=0,
    )
    cpu_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cpu",
        port_offset=1,
    )

    # NOTE(review): 170000 is presumably the buffer size threshold in bytes
    # (100 stress-test requests of 1700 bytes each) -- confirm against
    # SimpleBuffer's constructor.
    buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)

    test_run(my_rank, buffer, data_pipe.device)

    stress_test(my_rank, buffer, data_pipe.device)

    # Close the buffer first, then the pipes it was built on.
    buffer.close()
    data_pipe.close()
    cpu_pipe.close()
    print("Done")