test_nccl_symm_mem_allreduce.py 3.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import typing

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
13
from tests.utils import ensure_current_vllm_config
14
from vllm.distributed import cleanup_dist_env_and_memory
15
16
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
17
from vllm.distributed.device_communicators.pynccl_allocator import (
18
19
20
21
22
23
24
25
    get_nccl_mem_pool,
    is_symmetric_memory_enabled,
)
from vllm.distributed.parallel_state import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
26
from vllm.platforms import current_platform
27
from vllm.utils.system_utils import update_environment_variables
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

# Fixed seeds for reproducibility. Spawned workers re-import this module,
# so they run these calls too.
random.seed(44)
torch.manual_seed(42)

# Number of bf16 elements in the tensor each rank all-reduces (4 Mi).
test_size_elements = 4 << 20


def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
    """Per-rank body of the NCCL symmetric-memory allreduce test.

    Started once per rank by ``mp.spawn``, which passes the process index as
    ``local_rank``. Initializes a single-node distributed environment with one
    tensor-parallel group spanning all ``world_size`` ranks, runs
    ``torch.ops.vllm.all_reduce_symmetric_with_copy`` on a random bf16 tensor,
    and checks the result against ``torch.distributed.all_reduce`` of the same
    input on the TP device group.

    Args:
        local_rank: This process's rank; also used as its CUDA device index.
        world_size: Number of ranks, used as the tensor-parallel size.
    """
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        # Drop any inherited device mask, presumably so that cuda:<local_rank>
        # indexes the physical GPUs directly. Restored when the context exits.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        # Rendezvous variables, presumably consumed by
        # init_distributed_environment() below.
        # NOTE(review): the fixed port may collide with other jobs on the host.
        update_environment_variables(
            {
                "RANK": str(local_rank),
                "LOCAL_RANK": str(local_rank),
                "WORLD_SIZE": str(world_size),
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": "12345",
            }
        )

        init_distributed_environment()
        with ensure_current_vllm_config():
            initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(
            CudaCommunicator, get_tp_group().device_communicator
        )
        pynccl_comm = cuda_communicator.pynccl_comm

        # NOTE(review): these skips run inside an mp.spawn child. pytest.skip
        # raises pytest's Skipped (a BaseException), which the spawn wrapper
        # does not forward as a skip. The parent therefore most likely reports
        # a failed child process instead of a skipped test. Consider moving
        # these checks into the parent test.
        if get_nccl_mem_pool() is None:
            pytest.skip(
                "NCCL allocator compilation failed (probably missing NCCL headers)."
            )
        if not is_symmetric_memory_enabled():
            pytest.skip("NCCL symmetric memory allreduce is disabled.")

        register_nccl_symmetric_ops(pynccl_comm)

        # Integers in [1, 22]: bf16 represents integers up to 256 exactly, so
        # the summed values stay exact for world_size <= 11.
        input_tensor = torch.randint(
            1, 23, (test_size_elements,), dtype=dtype, device=device
        )
        # Take the reference copy before the symmetric op runs.
        reference = input_tensor.clone()
        output = torch.ops.vllm.all_reduce_symmetric_with_copy(input_tensor)
        assert output is not None

        group = get_tp_group().device_group
        dist.all_reduce(reference, group=group)
        torch.testing.assert_close(output, reference, atol=2.5, rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size: int):
    """Run the NCCL symmetric-memory allreduce check across ``world_size`` GPUs.

    Spawns one ``nccl_symm_mem_allreduce_worker`` process per rank. Environment
    variables set here through ``monkeypatch`` are inherited by the spawned
    children and restored for this process when the test ends.
    """
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Opt in to vLLM's NCCL symmetric-memory allreduce path, plus NCCL
    # features it presumably relies on (NVLS and cuMem-based allocation).
    monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
    monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
    monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")

    # Clean up even when a worker fails and mp.spawn raises.
    try:
        mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size)
    finally:
        cleanup_dist_env_and_memory()