# pynccl.py
from contextlib import contextmanager
2
from typing import Optional, Union
3
4
5
6

# ===================== import region =====================
import torch
import torch.distributed as dist
7
from torch.distributed import ProcessGroup, ReduceOp
8

9
10
11
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
12
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
13
14
15
from vllm.logger import init_logger

# Module-level logger, obtained from `init_logger` with this module's name.
logger = init_logger(__name__)
16
17


18
class PyNcclCommunicator:
    """Communicator that issues NCCL collectives through `NCCLLibrary`.

    The communicator is attached to a non-NCCL process group, which is only
    used to broadcast the NCCL unique id during construction. After
    construction the communicator is disabled by default; enable it
    temporarily with `change_state`.
    """

    def __init__(
        self,
        group: Optional[ProcessGroup] = None,
        device: Optional[Union[int, str, torch.device]] = None,
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, the group returned
                by `get_cpu_world_group()` is used. It must not use the
                NCCL backend.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}". An int is
                interpreted as a CUDA device index, a str is passed to
                `torch.device`.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.

        Attributes set on every path: `group`, `rank`, `world_size`,
        `available`, `disabled`, `stream`. When the group has a single rank
        or the NCCL library cannot be loaded, the communicator is marked
        unavailable and no NCCL state (`unique_id`, `device`, `comm`) is
        created.
        """
        assert dist.is_initialized()
        group = get_cpu_world_group() if group is None else group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self._mark_unavailable()
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self._mark_unavailable()
            return

        self.available = True
        # temporarily enabled so that the warmup all_reduce below runs
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()
        # share rank 0's unique id with every rank of the group through
        # the (non-NCCL) process group
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte

        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            self.all_reduce(torch.zeros(1, device=device))
            self.stream.synchronize()

        # by default it is disabled, e.g. in profiling models and prefill phase.
        # to use it, use under `with obj.change_state(enable=True)`, usually
        # when we are using CUDA graph.
        self.disabled = True

    def _mark_unavailable(self) -> None:
        """Record that no NCCL communicator was created for this instance."""
        self.available = False
        self.disabled = True
        self.stream = None

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """All-reduce `tensor` in place across the ranks of the communicator.

        Args:
            tensor: the tensor to reduce. It must live on `self.device`; its
                data pointer is used as both the send and receive buffer.
            op: the reduction operation, converted with
                `ncclRedOpTypeEnum.from_torch`.
            stream: the CUDA stream handed to `ncclAllReduce`. If None,
                `self.stream` is used.

        Returns:
            None. When the communicator is disabled this is a no-op and
            `tensor` is left untouched.

        Raises:
            AssertionError: if `tensor` is not on `self.device`.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()), tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to change the state of the communicator.

        Args:
            enable: whether the communicator is enabled inside the context.
                If None, it defaults to `self.available`.
            stream: the stream to use inside the context. If None, the
                current `self.stream` is kept.

        The previous `disabled` flag and `stream` are restored when the
        context exits, including when the body raises an exception.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        try:
            yield
        finally:
            # always restore, so an exception inside the `with` body cannot
            # leave the communicator enabled or bound to the caller's stream
            self.disabled = old_disable
            self.stream = old_stream