test_worker.py 4.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.offloading.worker.worker import (OffloadingHandler,
                                              OffloadingWorker, TransferResult,
                                              TransferSpec)


class LoadStoreSpec1(LoadStoreSpec):
    """Fake load/store spec for medium "1", carrying test-control flags.

    Args:
        submit_success: when falsy, OffloadingHandler1To2 refuses to
            submit a transfer whose source is this spec.
        async_success: the success value reported once the transfer
            tracking this spec is marked finished.
        exception: when truthy, OffloadingHandler1To2 raises on submission
            of a transfer whose source is this spec.
    """

    def __init__(self,
                 submit_success: bool = True,
                 async_success: bool = True,
                 exception: bool = False):
        # Knobs read by the fake handlers in this file.
        self.exception = exception
        self.async_success = async_success
        self.submit_success = submit_success
        # The test flips this to True to simulate transfer completion.
        self.finished = False

    @staticmethod
    def medium() -> str:
        return "1"

    def __repr__(self) -> str:
        return "%s: %s" % (self.medium(), id(self))


class LoadStoreSpec2(LoadStoreSpec):
    """Fake load/store spec for medium "2"; it carries no extra state."""

    @staticmethod
    def medium() -> str:
        return "2"

    def __repr__(self) -> str:
        return "{}: {}".format(self.medium(), id(self))


class OffloadingHandler1To2(OffloadingHandler):
    """Fake handler for LoadStoreSpec1 -> LoadStoreSpec2 transfers.

    Submission behavior is steered by flags on the source spec: a truthy
    ``exception`` makes submission raise, a falsy ``submit_success`` makes
    it return False, and otherwise the job is tracked until the test sets
    the source spec's ``finished`` flag.
    """

    def __init__(self):
        # Pending jobs keyed by job id; the value is the source spec whose
        # ``finished`` flag signals completion.
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """Track ``job_id`` unless the source spec asks to fail."""
        source, destination = spec
        assert isinstance(source, LoadStoreSpec1)
        assert isinstance(destination, LoadStoreSpec2)

        if source.exception:
            raise Exception("An expected exception. Don't worry!")

        accepted = bool(source.submit_success)
        if accepted:
            self.transfers[job_id] = source
        return accepted

    def get_finished(self) -> list[TransferResult]:
        """Pop and return (job_id, success) for every finished job."""
        # Collect ids first so the dict is not mutated while iterating it.
        done_ids = [job_id for job_id, source in self.transfers.items()
                    if source.finished]
        return [(job_id, self.transfers.pop(job_id).async_success)
                for job_id in done_ids]


class OffloadingHandler2To1(OffloadingHandler):
    """Fake handler for LoadStoreSpec2 -> LoadStoreSpec1 transfers.

    Every submission is accepted; a job completes once the test sets the
    destination spec's ``finished`` flag.
    """

    def __init__(self):
        # Pending jobs keyed by job id; the value is the destination spec
        # (a LoadStoreSpec1) whose ``finished`` flag signals completion.
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """Track ``job_id`` under its destination spec; always accepts."""
        source, destination = spec
        assert isinstance(source, LoadStoreSpec2)
        assert isinstance(destination, LoadStoreSpec1)

        self.transfers[job_id] = destination
        return True

    def get_finished(self) -> list[TransferResult]:
        """Pop and return (job_id, success) for every finished job."""
        # Snapshot finished ids before popping to avoid mutating during
        # iteration.
        done_ids = [job_id for job_id, destination in self.transfers.items()
                    if destination.finished]
        return [(job_id, self.transfers.pop(job_id).async_success)
                for job_id in done_ids]


def test_offloading_worker():
    """
    Tests OffloadingWorker with 2 handlers.

    One handler performs 1->2 transfers, and the other handles 2->1.
    Covers a handler raising on submission, a handler refusing submission,
    a transfer that fails asynchronously, successful transfers in both
    directions, and that each finished transfer is reported exactly once.
    """
    worker = OffloadingWorker()
    handler1to2 = OffloadingHandler1To2()
    handler2to1 = OffloadingHandler2To1()
    worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2)
    worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1)

    # 1st transfer 1->2 (handler raises; worker reports a failed submit)
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    assert not worker.transfer_async(1, (src1, dst1))

    # 2nd transfer 1->2 (failure to submit)
    src2 = LoadStoreSpec1(submit_success=False)
    dst2 = LoadStoreSpec2()
    assert not worker.transfer_async(2, (src2, dst2))

    # 3rd transfer 1->2 (submitted, but fails asynchronously)
    src3 = LoadStoreSpec1(async_success=False)
    dst3 = LoadStoreSpec2()
    assert worker.transfer_async(3, (src3, dst3))

    # 4th transfer 1->2 (success)
    src4 = LoadStoreSpec1()
    dst4 = LoadStoreSpec2()
    assert worker.transfer_async(4, (src4, dst4))
    # Jobs 1 and 2 never reached the handler's pending set.
    assert set(handler1to2.transfers.keys()) == {3, 4}

    # 5th transfer 2->1
    src5 = LoadStoreSpec2()
    dst5 = LoadStoreSpec1()
    assert worker.transfer_async(5, (src5, dst5))
    assert set(handler2to1.transfers.keys()) == {5}

    # no transfer completed yet
    assert worker.get_finished() == []

    # complete 3rd, 4th
    src3.finished = True
    src4.finished = True

    # 6th transfer 1->2
    src6 = LoadStoreSpec1()
    dst6 = LoadStoreSpec2()
    assert worker.transfer_async(6, (src6, dst6))

    # 7th transfer 2->1
    src7 = LoadStoreSpec2()
    dst7 = LoadStoreSpec1()
    assert worker.transfer_async(7, (src7, dst7))

    # 6th and 7th transfers started
    assert 6 in handler1to2.transfers
    assert 7 in handler2to1.transfers

    # verify result of 3rd and 4th transfers; results come from two
    # handlers, so compare in sorted order
    assert sorted(worker.get_finished()) == [(3, False), (4, True)]

    # complete 6th and 7th transfers
    src6.finished = True
    dst7.finished = True
    assert sorted(worker.get_finished()) == [(6, True), (7, True)]

    # finished results are reported only once; job 5 is still pending
    assert worker.get_finished() == []
    assert not handler1to2.transfers
    assert set(handler2to1.transfers.keys()) == {5}