# pynccl_utils.py
import contextlib
from typing import Optional

import torch
from torch.distributed import ReduceOp

7
8
9
from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                              ncclGetVersion)
except Exception as e:
    # In non-NVIDIA environments we can't import the nccl module,
    # e.g. when running on machines with AMD GPUs. The failure is logged and
    # swallowed so that this module stays importable; NCCLCommunicator and
    # ncclGetVersion are left undefined in that case.
    # Lazy %-style arguments defer string formatting to the logger.
    logger.info("Failed to import NCCL library: %s", e)
    logger.info("It is expected if you are not running on NVIDIA GPUs.")

# Module-level communicator shared by every helper in this file. It is None
# until init_process_group() assigns it, and None again after
# destroy_process_group() clears it.
comm: Optional["NCCLCommunicator"] = None


def is_initialized() -> bool:
    """Tell whether the module-level NCCL communicator currently exists."""
    communicator_missing = comm is None
    return not communicator_missing


@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Temporarily set the cuda stream used for communication.

    The communicator's ``stream`` attribute is replaced by ``stream`` for the
    duration of the ``with`` block and put back to its previous value on
    exit, including when the block raises.

    Args:
        stream: The cuda stream to assign to the communicator.

    Raises:
        AssertionError: If no communicator has been created yet.
    """
    assert comm is not None
    # Hold a local reference so the restore below targets the same
    # communicator even if the module-level ``comm`` is rebound meanwhile.
    active_comm = comm
    previous_stream = active_comm.stream
    active_comm.stream = stream
    try:
        yield
    finally:
        # Undo the override so it does not leak past the ``with`` block.
        active_comm.stream = previous_stream


40
41
42
43
def init_process_group(world_size: int,
                       rank: int,
                       init_method: str,
                       local_rank: int = -1) -> None:
    """Create the module-level NCCL communicator.

    Logs the NCCL version reported by ``ncclGetVersion()`` and stores a new
    ``NCCLCommunicator`` in the module-level ``comm``. All four arguments are
    forwarded unchanged, by keyword, to the ``NCCLCommunicator`` constructor.

    Args:
        world_size: Forwarded as ``world_size`` (presumably the number of
            participating processes -- confirm against NCCLCommunicator).
        rank: Forwarded as ``rank``.
        init_method: Forwarded as ``init_method``.
        local_rank: Forwarded as ``local_rank``; defaults to -1.

    Raises:
        AssertionError: If a communicator has already been created.
    """
    global comm
    assert not is_initialized()
    # Lazy %-style arguments defer string formatting to the logger.
    logger.info("vLLM is using nccl==%s", ncclGetVersion())
    comm = NCCLCommunicator(init_method=init_method,
                            world_size=world_size,
                            local_rank=local_rank,
                            rank=rank)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group.

    The work is delegated to the module-level communicator's ``all_reduce``.
    Nothing is returned; the result is presumably written into ``input_`` by
    the communicator -- confirm against NCCLCommunicator.

    Args:
        input_: Tensor to reduce; must be a cuda tensor.
        op: Reduction operation handed to the communicator; defaults to
            ``ReduceOp.SUM``.

    Raises:
        AssertionError: If ``input_`` is not a cuda tensor, or if no
            communicator has been created yet.
    """
    # The device check comes first, so a CPU tensor is rejected before the
    # communicator is consulted at all.
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    assert comm is not None
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    """Clear the module-level communicator.

    Afterwards ``is_initialized()`` reports False again. Safe to call when
    no communicator exists.
    """
    # Rebinding the module-level name is all that happens here.
    global comm
    comm = None


def get_world_size() -> int:
    """Returns the world size.

    The value is read from the module-level communicator's ``world_size``
    attribute.

    Raises:
        AssertionError: If no communicator has been created yet.
    """
    assert comm is not None
    return comm.world_size


71
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
    """Returns the module-level communicator, or None if none exists."""
    return comm