test_utils.py 1.23 KB
Newer Older
1
2
import ray

3
4
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized, init_distributed_environment)
5
6
7
8
9
10
11
12
from vllm.utils import get_open_port


def init_test_distributed_environment(
    pipeline_parallel_size: int,
    tensor_parallel_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed and model-parallel state for one test worker.

    Builds a ``tcp://localhost:<port>`` init method, calls
    ``init_distributed_environment`` with a world size of
    ``pipeline_parallel_size * tensor_parallel_size``, and then calls
    ``ensure_model_parallel_initialized`` with the two parallel sizes.

    Args:
        pipeline_parallel_size: Pipeline-parallel size; multiplied with
            ``tensor_parallel_size`` to obtain the world size.
        tensor_parallel_size: Tensor-parallel size.
        rank: Rank forwarded unchanged to ``init_distributed_environment``.
        distributed_init_port: Port on localhost used for the TCP init
            method. It is only interpolated into the URL string, so any
            value whose string form is a valid port works.
        local_rank: Forwarded unchanged to ``init_distributed_environment``;
            defaults to ``-1``.
    """
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pipeline_parallel_size * tensor_parallel_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tensor_parallel_size,
                                      pipeline_parallel_size)
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41


def multi_process_tensor_parallel(
    tensor_parallel_size: int,
    test_target,
) -> None:
    """Run ``test_target`` once per tensor-parallel rank as Ray tasks.

    Starts Ray, obtains a port from ``get_open_port()``, submits
    ``test_target.remote(tensor_parallel_size, rank, port)`` for every rank
    in ``range(tensor_parallel_size)``, and blocks on ``ray.get`` until all
    submitted tasks have finished. Ray is shut down afterwards, including
    when submission or a task raises; the exception still propagates to the
    caller.

    Args:
        tensor_parallel_size: Number of ranks (and therefore tasks) to launch.
        test_target: Object exposing ``.remote(tensor_parallel_size, rank,
            distributed_init_port)`` -- presumably a ``@ray.remote``
            function; confirm against callers.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init()
    try:
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(tensor_parallel_size, rank,
                               distributed_init_port)
            for rank in range(tensor_parallel_size)
        ]
        ray.get(refs)
    finally:
        # Always tear Ray down so a failing test does not leave an
        # initialized Ray runtime behind in this process.
        ray.shutdown()