# pynccl.py
1
from contextlib import contextmanager
2
from typing import Optional, Union
3
4
5
6

# ===================== import region =====================
import torch
import torch.distributed as dist
7
from torch.distributed import ProcessGroup, ReduceOp
8

9
10
11
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
12
13
14
from vllm.logger import init_logger

logger = init_logger(__name__)
15
16


17
class PyNcclCommunicator:
    """Communicator that issues NCCL calls through the `NCCLLibrary` wrapper.

    The NCCL unique id is created on group rank 0 and shared with the other
    ranks via a `torch.distributed` broadcast on a non-NCCL process group.
    After construction the communicator is disabled; operations are no-ops
    until it is enabled through `change_state`.
    """

    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. It must not use the NCCL
                backend; it is only used to broadcast the NCCL unique id.
            device: the device to bind the PyNcclCommunicator to. An int is
                interpreted as the CUDA device index, a str is passed to
                `torch.device`, and a `torch.device` is used as is.
            library_path: the path to the NCCL library, forwarded to
                `NCCLLibrary`. If None, it will use the default library
                path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.available = True
        # temporarily enabled so that the warmup all_reduce below is
        # actually executed; switched off again at the end of __init__
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()
        # ship the unique id from group rank 0 to everyone else as raw bytes
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # by default it is disabled, e.g. in profiling models and prefill phase.
        # to use it, use under `with obj.change_state(enable=True)`, usually
        # when we are using CUDA graph.
        self.disabled = True

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """In-place all-reduce of `tensor` across the communicator.

        Does nothing when the communicator is disabled.

        Args:
            tensor: the tensor to reduce; it is used as both the send and
                the receive buffer and must live on `self.device`.
            op: the reduction operation.
            stream: the CUDA stream to enqueue on. If None, `self.stream`
                is used.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()), tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send `tensor` to rank `dst` of this communicator.

        Does nothing when the communicator is disabled.

        Args:
            tensor: the tensor to send; must live on `self.device`.
            dst: the destination rank, passed to `ncclSend` as the peer.
            stream: the CUDA stream to enqueue on. If None, `self.stream`
                is used.
        """
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into `tensor` from rank `src` of this communicator.

        Does nothing when the communicator is disabled.

        Args:
            tensor: the destination tensor; must live on `self.device`.
            src: the source rank, passed to `ncclRecv` as the peer.
            stream: the CUDA stream to enqueue on. If None, `self.stream`
                is used.
        """
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to change the state of the communicator.

        Inside the `with` body the communicator is enabled/disabled and
        uses the given stream; on exit the previous `disabled` flag and
        stream are restored, even if the body raises.

        Args:
            enable: whether the communicator should be enabled inside the
                context. If None, it defaults to `self.available`.
            stream: the stream to use inside the context. If None, the
                current `self.stream` is kept.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        try:
            yield
        finally:
            # always roll back, otherwise an exception in the body would
            # leave the communicator stuck in the temporary state
            self.disabled = old_disable
            self.stream = old_stream