# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform

try:
    import torch.distributed._symmetric_memory as torch_symm_mem

    symm_mem_available = True
except ImportError:
    symm_mem_available = False

logger = init_logger(__name__)


class SymmMemCommunicator:
    """Sum all-reduce communicator built on PyTorch symmetric memory.

    The communicator starts out disabled and only enables itself when every
    check in ``__init__`` passes (symmetric memory importable, CUDA platform,
    supported device capability and world size, multicast support). While
    disabled, ``all_reduce`` returns ``None`` so that callers can fall back
    to another implementation.

    Only ``torch.bfloat16`` tensors are handled; data is staged through a
    pre-allocated symmetric-memory buffer of ``max_size`` bytes.
    """

    # Device capability (version string) -> world sizes for which the
    # multimem all-reduce kernel is selected. Any other configuration
    # uses the two-shot kernel.
    _WORLD_SIZES_MULTIMEM = {
        "9.0": [4, 6, 8],
        "10.0": [6, 8],
    }

    def __init__(
            self,
            group: ProcessGroup,
            device: Union[int, str, torch.device],
            # add options for testing
            force_multimem: Optional[bool] = None,
            max_size_override: Optional[int] = None):
        """Set up the symmetric-memory buffer for ``group`` on ``device``.

        Args:
            group: Process group whose ranks take part in the all-reduce.
            device: CUDA device, given as an index, a device string or a
                ``torch.device``.
            force_multimem: Testing option. When not ``None`` it overrides
                the automatic choice between the multimem and two-shot
                kernels.
            max_size_override: Testing option. When not ``None`` it replaces
                the default buffer size (in bytes) looked up from
                ``SYMM_MEM_ALL_REDUCE_MAX_SIZES``.
        """
        self.disabled = True
        # Stored before any early return so the attribute always exists,
        # even on an instance that stays disabled.
        self.force_multimem = force_multimem

        if not symm_mem_available:
            return

        if not current_platform.is_cuda():
            logger.warning("SymmMemCommunicator: symmetric "
                           "memory is not available.")
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = current_platform.get_device_capability(
        ).as_version_str()
        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
                self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        # Use override max_size if provided, otherwise use default
        if max_size_override is not None:
            self.max_size = max_size_override
            logger.info(
                "SymmMemCommunicator: Using override max_size: %s bytes",
                self.max_size,
            )
        else:
            self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
                self.device_capability][self.world_size]

        # max_size is in bytes; the buffer is allocated in elements.
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning("SymmMemCommunicator: symmetric memory "
                           "multicast operations are not supported.")
            return
        self.disabled = False

    def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
        """Return whether ``inp`` can be reduced by this communicator.

        The communicator must be enabled, ``inp`` must have the
        communicator's dtype, and its size in bytes must be a multiple of 4
        and strictly smaller than ``max_size``.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # NOTE(review): presumably an alignment requirement of the
        # symmetric-memory kernels; confirm against the kernel docs.
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def _should_use_multimem(self) -> bool:
        """Pick the multimem kernel (True) or the two-shot kernel (False)."""
        if self.force_multimem is not None:
            # Test override: use forced setting
            return self.force_multimem
        # Normal logic: use multimem for supported world sizes. A capability
        # missing from the table falls back to two-shot instead of raising
        # KeyError.
        return self.world_size in self._WORLD_SIZES_MULTIMEM.get(
            self.device_capability, ())

    def all_reduce(
            self,
            inp: torch.Tensor,
            *,
            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """Sum-reduce ``inp`` across the group via the symmetric buffer.

        Args:
            inp: Tensor to reduce; it is not modified.
            out: Optional destination tensor. When omitted, a new tensor
                shaped like ``inp`` is allocated.

        Returns:
            The tensor holding the reduced values, or ``None`` when
            ``should_use_symm_mem(inp)`` is false and the caller has to use
            another all-reduce path.
        """
        if not self.should_use_symm_mem(inp):
            return None
        if out is None:
            out = torch.empty_like(inp)

        # Stage the input in the symmetric buffer. reshape(-1) matches
        # view(-1) for contiguous tensors and also accepts non-contiguous
        # ones, where view(-1) would raise.
        staging = self.buffer[:inp.numel()]
        staging.copy_(inp.reshape(-1))

        if self._should_use_multimem():
            torch.ops.symm_mem.multimem_all_reduce_(staging, "sum",
                                                    self.group.group_name)
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(staging, "sum",
                                                    self.group.group_name)
        # The kernels reduce in place; copy the result out of the buffer.
        out.copy_(staging.view(out.shape))
        return out