# pynccl_utils.py: module-level helpers around a single NCCLCommunicator.
import contextlib
from typing import Optional

import torch
from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger

# Module logger; also used by the guarded NCCL import that follows.
logger = init_logger(__name__)

try:
    from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                              ncclGetVersion)
except Exception as e:
    # In non-NVIDIA environments (e.g. machines with AMD GPUs) the NCCL
    # module cannot be imported. This is deliberately best-effort: the module
    # stays importable, but NCCLCommunicator / ncclGetVersion remain unbound,
    # so init_process_group() would raise NameError if called.
    logger.info("Failed to import NCCL library: %s", e)
    logger.info("It is expected if you are not running on NVIDIA GPUs.")

# The process-wide communicator: None until init_process_group() stores one,
# and reset to None by destroy_process_group().
comm: Optional["NCCLCommunicator"] = None


def is_initialized() -> bool:
    """Report whether a communicator is currently installed.

    Returns:
        ``True`` once ``init_process_group`` has stored a communicator in the
        module-level ``comm``; ``False`` before that, or after
        ``destroy_process_group`` has cleared it.
    """
    if comm is None:
        return False
    return True


@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Temporarily set the CUDA stream used by the NCCL communicator.

    Inside the ``with`` body the communicator's ``stream`` attribute is
    ``stream``. On exit, whether normal or via an exception, the previously
    configured stream is restored so the override does not leak past the block.

    Args:
        stream: The CUDA stream to use for communication inside the block.

    Raises:
        AssertionError: If the communicator has not been initialized.
    """
    # Capture the communicator object up front so the restore in ``finally``
    # still targets it even if the global is cleared inside the block.
    communicator = comm
    assert communicator is not None
    previous_stream = communicator.stream
    communicator.stream = stream
    try:
        yield
    finally:
        communicator.stream = previous_stream


def init_process_group(group: Optional[ProcessGroup] = None) -> None:
    """Create and install the module-level NCCL communicator.

    Logs the NCCL library version, then stores a new ``NCCLCommunicator``
    built for ``group`` in the module-level ``comm``.

    If the guarded NCCL import at module load failed, ``ncclGetVersion`` and
    ``NCCLCommunicator`` are unbound and this raises ``NameError``.

    Args:
        group: Process group forwarded to ``NCCLCommunicator(group=...)``.
            ``None`` is passed through unchanged (presumably meaning the
            default group; confirm in ``pynccl``).

    Raises:
        AssertionError: If a communicator is already installed.
    """
    global comm
    assert not is_initialized()
    logger.info("vLLM is using nccl==%s", ncclGetVersion())
    comm = NCCLCommunicator(group=group)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group.

    The work is delegated to the module-level communicator's ``all_reduce``.
    Nothing is returned, so the result is presumably written into ``input_``
    in place (TODO confirm in ``NCCLCommunicator.all_reduce``).

    Args:
        input_: Tensor to reduce; must be a CUDA tensor.
        op: Reduction operation, ``ReduceOp.SUM`` by default.

    Raises:
        AssertionError: If ``input_`` is not a CUDA tensor, or if the
            communicator has not been initialized.
    """
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    assert comm is not None
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    """Forget the module-level communicator.

    Afterwards the module-level ``comm`` is ``None``. No method is called on
    the old communicator object; the reference is simply dropped.
    """
    global comm
    comm = None


def get_world_size() -> int:
    """Returns the world size.

    Returns:
        The ``world_size`` attribute of the module-level communicator.

    Raises:
        AssertionError: If the communicator has not been initialized.
    """
    assert comm is not None
    return comm.world_size


def get_nccl_backend() -> Optional["NCCLCommunicator"]:
    """Return the module-level communicator, or ``None`` if uninitialized."""
    return comm