# pynccl_utils.py
import contextlib
from typing import Optional

import torch
from torch.distributed import ReduceOp

7
8
9
from vllm.logger import init_logger

logger = init_logger(__name__)
10
11

# The pynccl wrapper is imported on a best-effort basis: a failure is logged
# and otherwise ignored, which leaves NCCLCommunicator and ncclGetVersion
# undefined for the rest of this module.
try:
    from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                              ncclGetVersion)
except Exception as e:
    # in non-NVIDIA environments, we can't import the nccl module
    # e.g. when running on machines with AMD GPUs
    logger.info(f"Failed to import NCCL library: {e}")
    logger.info("It is expected if you are not running on NVIDIA GPUs.")

# Module-level communicator handle shared by every function in this file.
# It starts out as None, is assigned by init_process_group() and is reset to
# None by destroy_process_group().
comm: Optional["NCCLCommunicator"] = None


def is_initialized() -> bool:
    """Tell whether the module-level communicator has been created.

    Returns:
        True when ``comm`` currently holds an object, False while it is
        still ``None``.
    """
    if comm is None:
        return False
    return True


@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Temporarily set the cuda stream used for communication.

    Assigns ``stream`` to the ``stream`` attribute of the module-level
    communicator for the duration of the ``with`` block, then puts the
    previously assigned stream back on exit, including when the body raises.

    Args:
        stream: the CUDA stream to install on the communicator.

    Raises:
        AttributeError: if no communicator has been created yet, i.e. the
            module-level ``comm`` is still ``None``.
    """
    # Keep a reference to the communicator that is active on entry so the
    # restore below targets the same object even if `comm` is rebound while
    # the block is running.
    active = comm
    # NOTE(review): assumes the communicator already carries a `stream`
    # attribute before the first call -- confirm against NCCLCommunicator.
    previous = active.stream
    active.stream = stream
    try:
        yield
    finally:
        active.stream = previous


39
40
41
42
def init_process_group(world_size: int,
                       rank: int,
                       init_method: str,
                       local_rank: int = -1) -> None:
    """Create the module-level NCCL communicator.

    Logs the version reported by ``ncclGetVersion()`` and stores a freshly
    constructed ``NCCLCommunicator`` in the module-level ``comm``.

    Args:
        world_size: forwarded to ``NCCLCommunicator`` as ``world_size``.
        rank: forwarded to ``NCCLCommunicator`` as ``rank``.
        init_method: forwarded to ``NCCLCommunicator`` as ``init_method``.
        local_rank: forwarded to ``NCCLCommunicator`` as ``local_rank``.
            NOTE(review): the meaning of the default -1 is decided by
            ``NCCLCommunicator`` -- confirm there.

    Raises:
        AssertionError: if a communicator has already been created.
        NameError: if the pynccl import at module load time failed, since
            ``ncclGetVersion`` and ``NCCLCommunicator`` are then undefined.
    """
    global comm
    # Only one communicator may exist at a time; call
    # destroy_process_group() first to replace it.
    assert not is_initialized()
    logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
    comm = NCCLCommunicator(init_method=init_method,
                            world_size=world_size,
                            local_rank=local_rank,
                            rank=rank)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """Hand ``input_`` to the module-level communicator's ``all_reduce``.

    Args:
        input_: tensor to reduce; it must live on a CUDA device.
        op: reduction operation passed through to the communicator
            (``ReduceOp.SUM`` unless overridden).

    Raises:
        AssertionError: if ``input_`` is not a CUDA tensor.

    Nothing is returned. NOTE(review): the result is presumably written
    into ``input_`` itself -- confirm against NCCLCommunicator.all_reduce.
    """
    on_cuda = input_.is_cuda
    assert on_cuda, f"{input_} should be a cuda tensor"
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    """Forget the module-level communicator.

    Rebinds ``comm`` to ``None`` so that ``is_initialized()`` reports False
    again. No other teardown is performed here.
    """
    global comm
    # Dropping the reference is the only action taken by this function.
    comm = None


def get_world_size() -> int:
    """Look up the world size recorded on the module-level communicator.

    Returns:
        The communicator's ``world_size`` attribute.

    Raises:
        AttributeError: if no communicator has been created (``comm`` is
            ``None``).
    """
    size = comm.world_size
    return size


def get_nccl_backend():
    """Expose the module-level communicator object.

    Returns:
        Whatever ``comm`` currently refers to: the communicator created by
        ``init_process_group()``, or ``None`` when there is none.
    """
    backend = comm
    return backend