# test_symm_mem_allreduce.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import typing

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group,
                                             init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

# Seed both RNGs when the module is imported so generated inputs are
# reproducible from run to run.
random.seed(44)
torch.manual_seed(42)

# Number of elements in each tensor that gets all-reduced (4 * 1024 * 1024).
test_size_elements = 4 << 20


def _report_skip(reason: str, skip_queue: typing.Optional[typing.Any]) -> None:
    """Signal that the calling worker wants the test to be skipped.

    With a queue, the reason is handed to the parent process, which is the
    only place where ``pytest.skip`` is actually honoured. Without a queue the
    historical behaviour (raising via ``pytest.skip`` in place) is kept.
    """
    if skip_queue is None:
        pytest.skip(reason)
    skip_queue.put(reason)


def symm_mem_allreduce_worker(local_rank: int,
                              world_size: int,
                              skip_queue: typing.Optional[typing.Any] = None):
    """Per-rank body of the symmetric-memory all-reduce test.

    Sets up the distributed environment for this rank, then checks two paths
    against ``torch.distributed.all_reduce`` on the same input:

    1. ``symm_mem_comm.all_reduce`` called directly, and
    2. ``tensor_model_parallel_all_reduce``.

    Args:
        local_rank: Index of this process. Used both as the CUDA device index
            and as the distributed rank.
        world_size: Total number of ranks; also used as the tensor model
            parallel size.
        skip_queue: Optional multiprocessing queue shared with the parent.
            When provided, a skip reason is put on it and the worker returns
            normally instead of calling ``pytest.skip``. pytest's skip
            exception derives from ``BaseException``, so raising it inside a
            spawned child is not treated as a skip: the child exits non-zero
            and the parent test fails. When ``None``, ``pytest.skip`` is
            called directly (legacy behaviour).
    """
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        # Presumably so every rank can address all GPUs by index; the device
        # below is selected as cuda:{local_rank}.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            _report_skip("SymmMemCommunicator is not available or disabled.",
                         skip_queue)
            return

        # Integer-valued inputs in [1, 23), stored as bfloat16.
        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            _report_skip(
                "SymmMemCommunicator isn't used for this world and input size.",
                skip_queue)
            return

        # Keep a copy: the reference all-reduce below works in place.
        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        # Loose tolerances -- presumably to absorb bfloat16 rounding
        # differences between the two reduction implementations.
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size):
    """Spawn one worker per rank and check symmetric-memory all-reduce.

    The test is skipped when fewer GPUs than ``tp_size *
    pipeline_parallel_size`` are visible, or when a worker reports through the
    shared queue that the symmetric-memory communicator cannot be exercised.
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

    # Workers cannot skip the test themselves, so they send the reason back
    # here. The queue must come from the "spawn" context to be passed as an
    # argument to spawned processes.
    skip_queue = mp.get_context("spawn").SimpleQueue()
    try:
        mp.spawn(symm_mem_allreduce_worker,
                 args=(world_size, skip_queue),
                 nprocs=world_size)
        # mp.spawn has joined every worker, so any reason is already queued.
        skip_reason = None if skip_queue.empty() else skip_queue.get()
    finally:
        # Always release distributed state, even if a worker failed.
        cleanup_dist_env_and_memory()

    if skip_reason is not None:
        pytest.skip(skip_reason)