"vllm/vscode:/vscode.git/clone" did not exist on "8c87a9ad46dd8b972d4cd9c6cecb5b284c92f583"
test_collective_rpc.py 1.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
import pytest

from vllm import LLM

7
from ...utils import create_new_process_for_each_test
8
9
10
11


@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("backend", ["mp", "ray"])
@create_new_process_for_each_test()
def test_collective_rpc(tp_size, backend):
    """Check that ``LLM.collective_rpc`` reaches every tensor-parallel worker.

    The RPC target is given two ways: as the name of a method defined on a
    custom worker class (``"echo_rank"``), and as a plain function object
    that receives the worker as ``self``. In both cases each worker returns
    its ``rank``, so the collected results must equal
    ``list(range(tp_size))``.

    Args:
        tp_size: tensor-parallel size passed to ``LLM``.
        backend: distributed executor backend (``"mp"`` or ``"ray"``);
            replaced by ``None`` when ``tp_size == 1``.
    """
    if tp_size == 1 and backend == "ray":
        # With tp_size == 1 the backend is overridden to None below, so the
        # "ray" parametrization would only repeat the "mp" one.
        pytest.skip("Skip duplicate test case")
    if tp_size == 1:
        # NOTE(review): presumably lets vLLM pick its default executor for a
        # single worker — confirm.
        backend = None

    # intentionally define the method and class in the test function,
    # to test if they can be serialized and sent to the workers
    def echo_rank(self):
        return self.rank

    from vllm.worker.worker import Worker

    class MyWorker(Worker):

        def echo_rank(self):
            return self.rank

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              enforce_eager=True,
              load_format="dummy",
              tensor_parallel_size=tp_size,
              distributed_executor_backend=backend,
              worker_cls=MyWorker)
    expected = list(range(tp_size))
    # Exercise both RPC forms: a method name resolved on the worker class,
    # and a callable shipped to the workers.
    for method in ["echo_rank", echo_rank]:
        result = llm.collective_rpc(method)
        assert result == expected, (
            f"collective_rpc({method!r}) returned {result}, "
            f"expected {expected}")