# pynccl.py
1
from contextlib import contextmanager
2
from typing import Optional, Union
3
4
5
6

# ===================== import region =====================
import torch
import torch.distributed as dist
7
from torch.distributed import ProcessGroup, ReduceOp
8

9
10
11
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
12
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
13
14
15
from vllm.logger import init_logger

# Module-level logger used by PyNcclCommunicator (e.g. to report the NCCL
# version at initialization).
logger = init_logger(__name__)


class PyNcclCommunicator:
    """Run NCCL collectives through the ctypes bindings in ``NCCLLibrary``.

    The communicator is bootstrapped over a non-NCCL ``ProcessGroup``.
    Rank 0 of that group creates an NCCL unique id, and the id is broadcast
    to all ranks. Every rank then initializes an NCCL communicator bound to
    a single CUDA device.

    After construction the communicator is *disabled*, and ``all_reduce``
    is a no-op. Enable it temporarily with
    ``with comm.change_state(enable=True): ...``.
    """

    def __init__(
        self,
        group: Optional[ProcessGroup] = None,
        device: Optional[Union[int, str, torch.device]] = None,
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group used to exchange the NCCL unique id.
                It must not use the NCCL backend. If None, the group
                returned by ``get_cpu_world_group()`` is used.
            device: the device to bind the PyNcclCommunicator to. An int
                ``i`` means ``cuda:i``, and a str is passed to
                ``torch.device``. If None, it will be bound to
                f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.

        If the group has a single rank, or the NCCL library fails to load,
        the communicator is marked unavailable (``available`` is False,
        ``stream`` is None) and no NCCL communicator is created.
        """
        assert dist.is_initialized()
        group = get_cpu_world_group() if group is None else group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self._mark_unavailable()
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self._mark_unavailable()
            return

        self.available = True
        # Temporarily enabled so that the warmup all_reduce below runs.
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        self.unique_id = self._exchange_unique_id(group)

        self.device = self._resolve_device(device)
        # nccl communicator and stream will use this device.
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one.
        with torch.cuda.device(self.device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            data = torch.zeros(1, device=self.device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # by default it is disabled, e.g. in profiling models and prefill
        # phase. To use it, use under `with obj.change_state(enable=True)`,
        # usually when we are using CUDA graph.
        self.disabled = True

    def _mark_unavailable(self) -> None:
        """Put the communicator into the unusable, disabled state."""
        self.available = False
        self.disabled = True
        self.stream = None

    def _exchange_unique_id(self, group):
        """Create the NCCL unique id on group rank 0 and share it.

        The id's raw bytes travel as a byte tensor broadcast over ``group``.
        They are then copied back into the ``ncclUniqueId`` structure on
        every rank. Returns the resulting ``ncclUniqueId``.
        """
        if self.rank == 0:
            # get the unique id from NCCL
            unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id, filled in by the broadcast
            unique_id = ncclUniqueId()
        tensor = torch.ByteTensor(list(unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        for i, byte in enumerate(tensor.tolist()):
            unique_id.internal[i] = byte
        return unique_id

    @staticmethod
    def _resolve_device(device):
        """Normalize ``device`` to a ``torch.device``.

        None maps to ``cuda:{local_rank}``, an int ``i`` maps to ``cuda:i``,
        a str is parsed by ``torch.device``, and a ``torch.device`` is
        returned unchanged.
        """
        if device is None:
            return torch.device(f"cuda:{get_local_rank()}")
        if isinstance(device, int):
            return torch.device(f"cuda:{device}")
        if isinstance(device, str):
            return torch.device(device)
        assert isinstance(device, torch.device)
        return device

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """All-reduce ``tensor`` in place across the group with ``op``.

        This does nothing while the communicator is disabled. ``stream``
        defaults to ``self.stream``.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        # Send and receive buffers are the same pointer: the reduction is
        # performed in place.
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()), tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to temporarily change the state of the
        communicator.

        Args:
            enable: whether collectives should run inside the context. If
                None, this defaults to ``self.available``.
            stream: the stream to use inside the context. If None, the
                current ``self.stream`` is kept.

        The previous ``disabled`` flag and stream are restored on exit,
        including when the body raises.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        try:
            yield
        finally:
            # Restore even if the body raised, so the communicator is not
            # left enabled (or on a foreign stream) after an error.
            self.disabled = old_disable
            self.stream = old_stream