test_symm_mem_allreduce.py 5.12 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import queue
5
6
7
8
9
10
11
12
13
import random
import typing

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
14
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
15
16
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
17
18
19
20
21
22
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
23
24
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
25
from vllm.platforms import current_platform
26
from vllm.utils.system_utils import update_environment_variables
27
28
29
30

# Seed Python's and torch's random number generators with fixed values.
random.seed(44)
torch.manual_seed(42)

# Number of elements in each all-reduce input tensor used by the tests.
test_size_elements = 1024**2
32
33


34
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
    """Per-rank body of the symmetric-memory all-reduce test.

    Initializes the distributed environment with ``world_size``
    tensor-parallel ranks, then compares both ``symm_mem_comm.all_reduce``
    and ``tensor_model_parallel_all_reduce`` against
    ``torch.distributed.all_reduce`` on the tensor-parallel device group.

    Args:
        local_rank: Index of this worker; used as the CUDA device index and
            as the ``RANK`` / ``LOCAL_RANK`` environment values.
        world_size: Total number of workers; used as the tensor-parallel size.
        q: Queue on which a human-readable reason is posted when the
            symmetric-memory communicator cannot be exercised.
    """
    patcher = pytest.MonkeyPatch()
    vllm_config = VllmConfig(
        parallel_config=ParallelConfig(tensor_parallel_size=world_size)
    )

    def check_matches_reference(actual, expected):
        # Same tolerances for both comparisons performed below.
        torch.testing.assert_close(actual, expected, atol=2.5, rtol=0.1)

    with patcher.context() as m, set_current_vllm_config(vllm_config):
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)

        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)

        rank_str = str(local_rank)
        dist_env = {
            "RANK": rank_str,
            "LOCAL_RANK": rank_str,
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
        update_environment_variables(dist_env)

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        tp_group = get_tp_group()
        communicator = typing.cast(CudaCommunicator, tp_group.device_communicator)
        symm_mem_comm = communicator.symm_mem_comm

        # pytest.skip cannot be used under multiprocessing, so the reason is
        # reported through the queue instead.
        if symm_mem_comm is None or symm_mem_comm.disabled:
            q.put("SymmMemCommunicator is not available or disabled.")
            return

        shape = (test_size_elements,)
        direct_input = torch.randint(1, 23, shape, dtype=dtype, device=device)
        if not symm_mem_comm.should_use_symm_mem(direct_input):
            q.put("SymmMemCommunicator isn't used for this world and input size.")
            return

        # Direct call on the symmetric-memory communicator.
        direct_reference = direct_input.clone()
        direct_output = symm_mem_comm.all_reduce(direct_input)
        assert direct_output is not None

        device_group = tp_group.device_group
        dist.all_reduce(direct_reference, group=device_group)
        check_matches_reference(direct_output, direct_reference)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        tp_input = torch.randint(-23, 1, shape, dtype=dtype, device=device)
        tp_reference = tp_input.clone()
        tp_output = tensor_model_parallel_all_reduce(tp_input)
        dist.all_reduce(tp_reference, group=device_group)
        check_matches_reference(tp_output, tp_reference)
95
96
97
98


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_symm_mem_allreduce(
    monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size
):
    """Run ``symm_mem_allreduce_worker`` on ``tp_size * pipeline_parallel_size`` ranks.

    The test is skipped when fewer GPUs than ranks are visible, or when a
    worker posts a reason on the queue.

    Args:
        monkeypatch: pytest fixture (not used by the test body).
        tp_size: Tensor-parallel size from the parametrization.
        pipeline_parallel_size: Pipeline-parallel size from the
            parametrization.
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    q = mp.get_context("spawn").Queue()
    mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size)

    # Bind the name before the try so it is always defined afterwards, even
    # if q.get() raises something other than queue.Empty.
    skip_reason = None
    try:
        skip_reason = q.get(timeout=1)
    except queue.Empty:
        # No worker asked for a skip; the checks ran on every rank.
        pass
    finally:
        # Cleanup only: raising pytest.skip from here would replace any
        # exception that is already propagating.
        cleanup_dist_env_and_memory()

    if skip_reason is not None:
        pytest.skip(skip_reason)
120
121


122
123
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
    """Smoke test: build an engine with data_parallel_size=2 and
    tensor_parallel_size=2.

    The test passes when engine construction raises no error; it is skipped
    when fewer than four GPUs are visible.
    """
    gpus_needed = 4
    if torch.cuda.device_count() < gpus_needed:
        pytest.skip("Not enough GPUs to run the test.")

    # Verify that the DataParallel runs without error
    engine_options = dict(
        model="distilbert/distilgpt2",
        enforce_eager=True,
        enable_prefix_caching=True,
        data_parallel_size=2,
        tensor_parallel_size=2,
        data_parallel_backend="mp",
    )
    LLMEngine.from_engine_args(EngineArgs(**engine_options))