# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- run any code that uses the distributed environment.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
22
import contextlib
23
import gc
24
import pickle
25
import weakref
26
27
28
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
29
from multiprocessing import shared_memory
30
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
31
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
32
33

import torch
34
import torch.distributed
35
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
36

37
import vllm.envs as envs
38
from vllm.logger import init_logger
39
from vllm.platforms import current_platform
40
from vllm.utils import supports_custom_op
41
42


43
44
45
@dataclass
class GraphCaptureContext:
    """State handed out by ``GroupCoordinator.graph_capture``.

    Carries the CUDA stream that graph capture runs on, so the same stream
    can be passed back into a later ``graph_capture`` call.
    """

    # Stream on which graph capture is performed.
    stream: torch.cuda.Stream


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
49

50
51

def _split_tensor_dict(
52
53
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
54
55
56
57
58
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
59
    metadata_list: List[Tuple[str, Any]] = []
60
    tensor_list: List[torch.Tensor] = []
61
62
63
64
65
66
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
67
            device = value.device.type
68
            metadata_list.append(
69
                (key, TensorMetadata(device, value.dtype, value.size())))
70
71
            tensor_list.append(value)
        else:
72
            metadata_list.append((key, value))
73
74
75
    return metadata_list, tensor_list


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    # looks like Python 3.8 does not understand `ReferenceType`
    _groups[group.unique_name] = weakref.ref(group)  # type: ignore


100
if supports_custom_op():

    def _resolve_group(group_name: str) -> "GroupCoordinator":
        # Custom ops receive the group by name (Dynamo cannot pass the
        # coordinator object), so map the name back through the registry.
        assert group_name in _groups, f"Group {group_name} is not found."
        coordinator = _groups[group_name]()
        if coordinator is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        return coordinator

    @torch.library.custom_op("vllm::inplace_all_reduce",
                             mutates_args=["tensor"])
    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
        """All-reduce ``tensor`` in place over the named group."""
        _resolve_group(group_name)._all_reduce_in_place(tensor)

    @inplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> None:
        # Fake implementation: the in-place op produces no new tensor.
        return

    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
    def outplace_all_reduce(tensor: torch.Tensor,
                            group_name: str) -> torch.Tensor:
        """All-reduce ``tensor`` over the named group into a new tensor."""
        return _resolve_group(group_name)._all_reduce_out_place(tensor)

    @outplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        # Fake implementation: only the output's shape/dtype matter here.
        return torch.empty_like(tensor)


class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It can route the communication to
        a specific implementation (e.g. switch allreduce implementation
        based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
160
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
161
162
163
164
165
166
167
168

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create process groups for ``group_ranks`` and join the one that
        contains this process.

        Args:
            group_ranks: the groups to create, each a list of global ranks.
                ``torch.distributed.new_group`` is called for every entry
                (torch documents it as collective, so callers presumably
                pass the same list on every process). If this rank appears
                in several entries, the last one wins.
            local_rank: local rank, used to choose the ``cuda:N`` device.
            torch_distributed_backend: backend for the device groups; the
                CPU groups always use ``gloo``.
            use_pynccl, use_custom_allreduce, use_tpu_communicator,
            use_message_queue_broadcaster: whether to build the matching
                communicator; none is built for a single-rank group.
            group_name: prefix for the registry name (``"anonymous"`` when
                not given).
        """
        self.unique_name = _get_unique_name(group_name or "anonymous")
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for members in group_ranks:
            dev_pg = torch.distributed.new_group(
                members, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination
            # between processes through the CPU.
            host_pg = torch.distributed.new_group(members, backend="gloo")
            if self.rank not in members:
                continue
            self.ranks = members
            self.world_size = len(members)
            self.rank_in_group = members.index(self.rank)
            self.device_group = dev_pg
            self.cpu_group = host_pg

        assert self.cpu_group is not None
        assert self.device_group is not None

        device_name = (f"cuda:{local_rank}"
                       if current_platform.is_cuda_alike() else "cpu")
        self.device = torch.device(device_name)

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator

        # Communicators are only created when the group has more than one
        # rank.
        multi_rank = self.world_size > 1

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = (
            PyNcclCommunicator(group=self.cpu_group, device=self.device)
            if use_pynccl and multi_rank else None)

        # Custom fast all-reduce implementation.
        self.ca_comm: Optional[CustomAllreduce] = (
            CustomAllreduce(group=self.cpu_group, device=self.device)
            if use_custom_allreduce and multi_rank else None)

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator] = (
            TpuCommunicator(group=self.cpu_group)
            if use_tpu_communicator and multi_rank else None)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        # NOTE(review): `1 << 22` and `6` are passed positionally to
        # MessageQueue.create_from_process_group; their meaning (presumably
        # a chunk size in bytes and a number of chunks) should be confirmed
        # against that function.
        self.mq_broadcaster: Optional[MessageQueue] = (
            MessageQueue.create_from_process_group(self.cpu_group, 1 << 22, 6)
            if use_message_queue_broadcaster and multi_rank else None)

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

251
252
253
254
255
256
257
258
259
260
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Prepare this group's communicators for CUDA graph capture.

        Runs the body on the context's stream (a new ``torch.cuda.Stream``
        when no context is given), after making that stream wait for the
        current one. Enters ``ca_comm.capture()`` if a custom-allreduce
        communicator exists, and switches PyNccl on for the capture stream
        if a PyNccl communicator exists. Yields the ``GraphCaptureContext``
        in use.
        """
        if graph_capture_context is not None:
            stream = graph_capture_context.stream
        else:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)

        ca_context = (nullcontext()
                      if self.ca_comm is None else self.ca_comm.capture())

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        current = torch.cuda.current_stream()
        if current != stream:
            stream.wait_stream(current)

        with torch.cuda.stream(stream), ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph  |
            # --------------------------------------------
            # custom allreduce       | enabled | enabled |
            # PyNccl                 | disabled| enabled |
            # torch.distributed      | enabled | disabled|
            #
            # Note that custom allreduce will have a runtime check, if the
            #  tensor size is too large, it will fallback to the next
            #  available option.
            # In summary: When using CUDA graph, we use
            #  either custom all-reduce kernel or pynccl. When not using
            #  CUDA graph, we use either custom all-reduce kernel or
            #  PyTorch NCCL. We always prioritize using custom all-reduce
            #  kernel but fall back to PyTorch or pynccl if it is
            #  disabled or not supported.
            pynccl_comm = self.pynccl_comm
            pynccl_context: Any
            if pynccl_comm:
                pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream())
            else:
                pynccl_context = nullcontext()
            with pynccl_context:
                yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we need to figure out if the op is
        in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

341
        if not supports_custom_op():
342
343
            self._all_reduce_in_place(input_)
            return input_
344

345
346
347
        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
348
            return self.tpu_communicator.all_reduce(input_)
349

350
351
352
        if self.ca_comm is not None and \
            not self.ca_comm.disabled and \
                self.ca_comm.should_custom_ar(input_):
353
354
355
356
357
358
359
            return torch.ops.vllm.outplace_all_reduce(
                input_, group_name=self.unique_name)
        else:
            torch.ops.vllm.inplace_all_reduce(input_,
                                              group_name=self.unique_name)
            return input_

360
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
361
        ca_comm = self.ca_comm
362
363
364
365
366
        assert ca_comm is not None
        assert not ca_comm.disabled
        out = ca_comm.custom_all_reduce(input_)
        assert out is not None
        return out
367

368
    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
369
370
371
        pynccl_comm = self.pynccl_comm
        if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
372
373
374
        elif input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
375
376
377
378
379
380
381
382
383
384
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
385
386
387
388
389
390

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
414
               dim: int = -1) -> Optional[torch.Tensor]:
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

460
461
462
463
464
465
466
467
468
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
469
470
471
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
472
473
474
475
476
477
478
479
480
481
482
483
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

502
503
504
505
506
507
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

508
        assert dst != self.rank_in_group, (
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size

        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

538
        assert src != self.rank_in_group, (
539
540
541
542
543
544
545
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
546
                                           src=self.ranks[src],
547
548
549
550
551
552
553
554
555
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
556
                                             src=self.ranks[src],
557
558
559
560
561
562
563
564
565
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

566
567
    def broadcast_tensor_dict(
        self,
568
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
569
570
571
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
572
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
573
574
575
576
577
578
579
580
581
582
583
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

584
585
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
586
587
588
589
590
591
592
593
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
594
            self.broadcast_object(metadata_list, src=src)
595
596
597
598
599
600
601
602
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
603
                                                         src=self.ranks[src],
604
605
606
607
608
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
609
                                                         src=self.ranks[src],
610
611
612
613
614
615
616
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
617
            metadata_list = self.broadcast_object(None, src=src)
618
619
            tensor_dict = {}
            async_handles = []
620
            for key, value in metadata_list:
621
622
623
624
625
626
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
627
                        tensor_dict[key] = tensor
628
629
630
631
632
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
633
                            src=self.ranks[src],
634
635
636
637
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
638
639
640
641
642
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
643
                    async_handles.append(handle)
644
                    tensor_dict[key] = tensor
645
                else:
646
                    tensor_dict[key] = value
647
648
649
650
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

651
652
    def send_tensor_dict(
        self,
653
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
654
655
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
656
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
657
658
659
660
661
662
663
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

664
665
666
667
668
        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

669
670
671
672
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
673
            dst = (self.rank_in_group + 1) % self.world_size
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
689
690
691
692
693
694

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

695
696
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
697
698
699
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
700
701
            else:
                # use group for GPU tensors
702
703
704
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
705
706
707
708
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Receive a dictionary of tensors and plain values from `src`.

        The metadata list (pairs of key and either a plain value or a
        `TensorMetadata`) arrives first through `recv_object`. After that,
        every non-empty tensor is received point-to-point into a freshly
        allocated buffer.

        NOTE: `src` is the local rank of the source rank. When omitted, the
        previous rank in the group (wrapping around) is used.

        Args:
            src: local rank to receive from.
            all_gather_group: optional group used for "send-allgather". This
                applies to each tensor whose element count is divisible by
                that group's world size. Only this rank's slice is received;
                the full tensor is then rebuilt with `all_gather` over the
                group.

        Returns:
            The reconstructed dictionary, or None when torch.distributed is
            not initialized or this group has a single rank.
        """
        # Nothing to exchange without a distributed setup or with one rank.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        if all_gather_group is None:
            all_gather_size, all_gather_rank = 1, 0
        else:
            all_gather_size = all_gather_group.world_size
            all_gather_rank = all_gather_group.rank_in_group

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        metadata_list = self.recv_object(src=src)
        received: Dict[str, Any] = {}
        for key, value in metadata_list:
            if not isinstance(value, TensorMetadata):
                # Plain (non-tensor) values travel inside the metadata list.
                received[key] = value
                continue

            buffer = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if buffer.numel() == 0:
                # The sender skips empty tensors, so there is nothing to
                # receive; the empty buffer already has the right shape.
                received[key] = buffer
                continue

            # send-allgather: only this rank's slice is transferred here; the
            # full tensor is reassembled with all_gather below.
            sliced = (all_gather_group is not None
                      and buffer.numel() % all_gather_size == 0)
            target = (buffer.reshape(all_gather_size, -1)[all_gather_rank]
                      if sliced else buffer)

            # CPU tensors travel over the CPU (metadata) group; device
            # tensors travel over the device group.
            torch.distributed.recv(target,
                                   src=self.ranks[src],
                                   group=metadata_group
                                   if target.is_cpu else group)

            if sliced:
                gathered = all_gather_group.all_gather(  # type: ignore
                    target, dim=0)
                received[key] = gathered.reshape(buffer.shape)
            else:
                received[key] = target
        return received

773
774
775
776
777
778
779
780
781
    def barrier(self):
        """Block until every rank of this group reaches the barrier.

        NOTE: this deliberately runs on `cpu_group`, never `device_group`.
        An NCCL barrier is internally a broadcast over GPU tensors that are
        created behind the scenes, which can easily disturb the current
        device. The CPU group avoids that.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

782
783
784
785
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send `tensor` to rank `dst` of this group.

        NOTE: `dst` is the local rank of the destination rank. It defaults
        to the next rank in the group, wrapping around.

        pynccl is used when it exists and is not disabled. Otherwise this
        falls back to `torch.distributed.send` on the device group.
        NOTE(review): an earlier docstring called this "non-blocking".
        PyTorch documents `torch.distributed.send` as synchronous, so that
        can only hold on the pynccl path. Confirm before relying on it.
        """
        target = ((self.rank_in_group + 1) % self.world_size
                  if dst is None else dst)

        comm = self.pynccl_comm
        use_pynccl = comm is not None and not comm.disabled
        if use_pynccl:
            comm.send(tensor, target)
            return
        torch.distributed.send(tensor, self.ranks[target], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor of the given `size` and `dtype` from rank `src`.

        NOTE: `src` is the local rank of the source rank. It defaults to the
        previous rank in the group, wrapping around.

        A buffer is allocated on `self.device` and filled via pynccl when it
        exists and is not disabled, otherwise via `torch.distributed.recv`
        on the device group.

        Returns:
            The received tensor.
        """
        source = ((self.rank_in_group - 1) % self.world_size
                  if src is None else src)

        buffer = torch.empty(size, dtype=dtype, device=self.device)
        comm = self.pynccl_comm
        if comm is None or comm.disabled:
            torch.distributed.recv(buffer, self.ranks[source],
                                   self.device_group)
        else:
            comm.recv(buffer, source)
        return buffer

811
812
813
814
815
816
817
818
819
820
821
    def destroy(self):
        """Destroy this coordinator's process groups and drop communicators.

        The device group is destroyed before the CPU group. The pynccl,
        custom all-reduce and message-queue broadcaster references are
        cleared afterwards; no explicit close is called on them here.
        """
        for group_attr in ("device_group", "cpu_group"):
            process_group = getattr(self, group_attr)
            if process_group is not None:
                torch.distributed.destroy_process_group(process_group)
                setattr(self, group_attr, None)

        for comm_attr in ("pynccl_comm", "ca_comm", "mq_broadcaster"):
            if getattr(self, comm_attr) is not None:
                setattr(self, comm_attr, None)
824
825
826
827
828
829
830
831
832
833


# Coordinator for the group spanning every rank. init_distributed_environment
# sets it; destroy_distributed_environment resets it to None.
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator.

    Raises:
        AssertionError: if the world group has not been initialized.
    """
    world = _WORLD
    assert world is not None, ("world group is not initialized")
    return world


834
835
836
837
838
839
840
841
def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Create the coordinator for the single world group.

    Args:
        ranks: global ranks that form the (single) world group.
        local_rank: local rank of this process.
        backend: torch.distributed backend for the group.

    The group is named "world" and is created with pynccl, custom all-reduce
    and the TPU communicator all disabled.
    """
    disabled_communicators = dict(use_pynccl=False,
                                  use_custom_allreduce=False,
                                  use_tpu_communicator=False)
    return GroupCoordinator(group_ranks=[ranks],
                            local_rank=local_rank,
                            torch_distributed_backend=backend,
                            group_name="world",
                            **disabled_communicators)


847
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create a GroupCoordinator for a model-parallel (e.g. TP or PP) group.

    Args:
        group_ranks: one list of global ranks per group.
        local_rank: local rank of this process.
        backend: torch.distributed backend for the groups.
        use_custom_allreduce: whether to enable custom all-reduce. When None,
            the module-level `_ENABLE_CUSTOM_ALL_REDUCE` default is used.
        use_message_queue_broadcaster: forwarded to GroupCoordinator.
        group_name: optional name for the group.

    pynccl and the TPU communicator are always requested for these groups.
    """
    custom_allreduce = (_ENABLE_CUSTOM_ALL_REDUCE
                        if use_custom_allreduce is None else
                        use_custom_allreduce)
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
# Tensor-parallel group coordinator. initialize_model_parallel sets it,
# destroy_model_parallel clears it, and patch_tensor_parallel_group swaps it
# temporarily.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group coordinator.

    Raises:
        AssertionError: if the tensor parallel group is not initialized.
    """
    tp = _TP
    assert tp is not None, ("tensor model parallel group is not initialized")
    return tp


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

# Pipeline-parallel group coordinator. initialize_model_parallel sets it and
# destroy_model_parallel clears it.
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group coordinator.

    Raises:
        AssertionError: if the pipeline parallel group is not initialized.
    """
    pp = _PP
    assert pp is not None, (
        "pipeline model parallel group is not initialized")
    return pp


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
891
892


893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
@contextmanager
def graph_capture():
    """Context manager to wrap around code that captures a CUDA graph.

    It first enters the tensor-parallel group's `graph_capture()`. It then
    enters the pipeline-parallel group's `graph_capture(context)`, reusing
    the context from the TP group, and yields that `GraphCaptureContext`.

    Its purpose is to ensure that some operations run after the graph is
    captured and before it is replayed. The yielded context currently holds
    only the stream the capture runs on. That stream becomes the current
    CUDA stream on entry and is reset to the default stream on exit. This
    keeps the captured kernels separate from other kernels that may be
    launched in the background on the default stream.
    """
    tp_group = get_tp_group()
    with tp_group.graph_capture() as context:
        # The PP group reuses the context (stream) set up by the TP group.
        with get_pp_group().graph_capture(context):
            yield context

912

913
# Module-level logger for this file, obtained from vLLM's `init_logger`.
logger = init_logger(__name__)
914

915
# Module-wide default for custom all-reduce. init_model_parallel_group reads
# it whenever its `use_custom_allreduce` argument is None.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the module-wide default used for custom all-reduce.

    The value is read when a group is created through
    `init_model_parallel_group` with `use_custom_allreduce=None`. Groups
    created before this call are unaffected.
    """
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
921

Zhuohan Li's avatar
Zhuohan Li committed
922

923
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the vLLM world group.

    Args:
        world_size: total number of processes, passed to
            `torch.distributed.init_process_group`.
        rank: global rank of this process, passed likewise.
        distributed_init_method: init method for `init_process_group`. It
            must not be None when torch.distributed is not yet initialized.
        local_rank: rank of this process on its node. When -1, it is taken
            from `envs.LOCAL_RANK` for "env://" initialization and from
            `rank` otherwise.
        backend: backend for the default process group and the world group.

    If the world group already exists, this only asserts that its size
    matches the torch.distributed world size.
    """
    global _WORLD

    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    already_initialized = torch.distributed.is_initialized()
    if not already_initialized:
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(backend=backend,
                                             init_method=distributed_init_method,
                                             world_size=world_size,
                                             rank=rank)

    # torch's ProcessGroup does not expose a local rank
    # (see https://github.com/pytorch/pytorch/issues/122816), so derive it.
    if local_rank == -1:
        # Typically the single-node case, where the global rank can stand in
        # for the local rank.
        local_rank = (envs.LOCAL_RANK
                      if distributed_init_method == "env://" else rank)

    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
        return

    ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(ranks, local_rank, backend)
961
962


Zhuohan Li's avatar
Zhuohan Li committed
963
964
965
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. Defaults to
            the backend of the world group's device group.

    Raises:
        RuntimeError: if the world size is not exactly
            tensor_model_parallel_size * pipeline_model_parallel_size.
        AssertionError: if torch.distributed is not initialized, or if the
            tensor / pipeline parallel group already exists.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _TP, _PP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    expected_world_size = (tensor_model_parallel_size *
                           pipeline_model_parallel_size)
    if world_size != expected_world_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Tensor model-parallel groups: contiguous runs of ranks.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    assert _TP is None, ("tensor model parallel group is already initialized")
    tp_group_ranks = [
        list(range(i * tensor_model_parallel_size,
                   (i + 1) * tensor_model_parallel_size))
        for i in range(num_tensor_model_parallel_groups)
    ]
    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(tp_group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Pipeline model-parallel groups: ranks taken with a stride equal to
    # the number of pipeline groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    pp_group_ranks = [
        list(range(i, world_size, num_pipeline_model_parallel_groups))
        for i in range(num_pipeline_model_parallel_groups)
    ]
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(pp_group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False,
                                    group_name="pp")
1038

Zhuohan Li's avatar
Zhuohan Li committed
1039

1040
1041
1042
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups, or validate the existing ones.

    If the tensor/pipeline parallel groups do not exist yet, they are created
    via `initialize_model_parallel`. Otherwise this asserts that their world
    sizes equal `tensor_model_parallel_size` and
    `pipeline_model_parallel_size`.

    Args:
        tensor_model_parallel_size: expected tensor parallel world size.
        pipeline_model_parallel_size: expected pipeline parallel world size.
        backend: backend passed on to `initialize_model_parallel`. Defaults to
            the backend of the world group's device group.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if model_parallel_is_initialized():
        # Groups already exist: only check that they match the request.
        assert (
            get_tensor_model_parallel_world_size() == tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        pp_world_size = get_pp_group().world_size
        assert (pp_world_size == pipeline_model_parallel_size), (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{pp_world_size=} vs. "
            f"{pipeline_model_parallel_size=}")
        return

    initialize_model_parallel(tensor_model_parallel_size,
                              pipeline_model_parallel_size, backend)


Zhuohan Li's avatar
Zhuohan Li committed
1068
def model_parallel_is_initialized():
    """Return True when both the tensor and pipeline parallel groups exist."""
    return not (_TP is None or _PP is None)
1071
1072


1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
# True while patch_tensor_parallel_group has swapped in a temporary TP group.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    The original tensor parallel group is restored on exit, even when the
    body raises. Nested patching is rejected.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if a patch is already active, or if no tensor
            parallel group has been initialized yet.
    """
    global _TP_STATE_PATCHED
    global _TP
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Look up the current group *before* setting the flag. get_tp_group()
    # asserts that _TP is initialized. If the flag were set first and that
    # assertion fired, the flag would stay True forever, and every later call
    # would fail with "already patched".
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1101
1102
def get_tensor_model_parallel_world_size():
    """Return the world size of the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1104
1105
1106
1107


def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1109
1110
1111


def destroy_model_parallel():
    """Destroy the tensor and pipeline parallel groups and reset them to None.

    The tensor parallel group is torn down first, then the pipeline parallel
    group. Each global is set to None even if it was already unset.
    """
    global _TP, _PP

    tp_group = _TP
    if tp_group:
        tp_group.destroy()
    _TP = None

    pp_group = _PP
    if pp_group:
        pp_group.destroy()
    _PP = None


def destroy_distributed_environment():
    """Destroy the world group, then the default torch process group.

    The module-level world group is reset to None. The default process group
    is destroyed only if torch.distributed reports it as initialized.
    """
    global _WORLD

    world = _WORLD
    if world:
        world.destroy()
    _WORLD = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
1131
1132


1133
1134
1135
1136
1137
1138
1139
1140
1141
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM's distributed state and release cached memory.

    The steps run in this order:

    1. Destroy the model parallel groups.
    2. Destroy the world group and the default process group.
    3. Make a best-effort extra call to
       `torch.distributed.destroy_process_group()`; an AssertionError from it
       is ignored.
    4. Optionally shut Ray down.
    5. Run the garbage collector.
    6. Empty the CUDA cache, unless the current platform is CPU.

    Args:
        shutdown_ray: also call `ray.shutdown()` (Ray is imported lazily).
    """
    destroy_model_parallel()
    destroy_distributed_environment()

    # Best effort: the default process group may already be gone here.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()

    gc.collect()
    if current_platform.is_cpu():
        return
    torch.cuda.empty_cache()


1146
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    The source rank creates a small shared-memory segment holding a magic
    message and broadcasts its name (None if creation failed). Every other
    rank tries to open that segment and checks the message. The per-rank
    flags are then combined with `all_reduce`.

    Args:
        pg: a non-NCCL process group; the segment name is exchanged with
            `broadcast_object_list` over it.
        source_rank: rank *within* `pg` that creates the segment.

    Returns:
        A list with one entry per rank of `pg`. Entry i is True when rank i
        could read the source rank's segment. The source rank reports True
        for itself only if it created the segment.
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group.")
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == source_rank:
            shm_name = None
            with contextlib.suppress(OSError):
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                shm_name = shm.name
            # Always take part in the broadcast, even if creation failed
            # (shm_name is None). Every other rank waits in the matching
            # broadcast_object_list, so skipping it would deadlock them.
            torch.distributed.broadcast_object_list([shm_name],
                                                    src=ranks[source_rank],
                                                    group=pg)
            if shm_name is not None:
                is_in_the_same_node[rank] = 1
        else:
            # try to open the shared memory segment
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=ranks[source_rank],
                                                    group=pg)
            name = recv[0]
            # A None name means the source rank could not create a segment.
            # Failing to open the segment (e.g. it lives on another node)
            # raises OSError, which leaves this rank's flag at 0.
            if name is not None:
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is
                    # not created by the process. The following patch is a
                    # workaround.
                    with patch("multiprocessing.resource_tracker.register",
                               lambda *args, **kwargs: None):
                        shm = shared_memory.SharedMemory(name=name)
                    if shm.buf[:len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return [x == 1 for x in is_in_the_same_node.tolist()]