# test_utils.py
import ray

3
4
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
5
6
7
8
from vllm.utils import get_open_port


def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    """Set up the distributed state for a single worker of a parallel test.

    Builds a ``tcp://localhost:<port>`` init method, passes it to
    ``init_distributed_environment`` with a world size of
    ``pp_size * tp_size``, and then calls
    ``ensure_model_parallel_initialized(tp_size, pp_size)``.

    Args:
        tp_size: Tensor-parallel size; one factor of the world size.
        pp_size: Second factor of the world size (presumably the
            pipeline-parallel size -- confirm against vllm.distributed).
        rank: Rank of the calling worker, forwarded unchanged.
        distributed_init_port: Port interpolated into the TCP init method.
            NOTE(review): the caller visible in this file forwards the value
            returned by ``get_open_port()`` as-is; confirm whether that is an
            int and whether this annotation should be widened.
        local_rank: Forwarded unchanged to ``init_distributed_environment``;
            defaults to -1.
    """
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pp_size * tp_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tp_size, pp_size)
22
23
24


def multi_process_tensor_parallel(
    tp_size: int,
    pp_size: int,
    test_target,
) -> None:
    """Run ``test_target`` once per rank as Ray remote tasks and wait for all.

    Launches ``tp_size * pp_size`` invocations of ``test_target.remote``,
    each receiving ``(tp_size, pp_size, rank, distributed_init_port)``, and
    blocks on ``ray.get`` until all of them have completed.

    Args:
        tp_size: Tensor-parallel size; one factor of the number of ranks.
        pp_size: Second factor of the number of ranks.
        test_target: Object exposing ``.remote(...)`` (e.g. a function
            decorated with ``ray.remote``) that runs the per-rank test body.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init()
    try:
        # Every rank receives the same port so they can rendezvous on it.
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(tp_size, pp_size, rank, distributed_init_port)
            for rank in range(tp_size * pp_size)
        ]
        ray.get(refs)
    finally:
        # Always tear Ray down, even when a worker fails, so that a failing
        # test does not leave the runtime initialized for later tests.
        ray.shutdown()