dist.py
import os
import torch

from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM

# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


# Minimal stand-ins for torch.distributed primitives, used when no real
# process group is created (single process or DEBUG mode).
class FakeBarrier:
    def wait(self):
        pass


class FakeGroup:
    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_torch_distributed():
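    """Set up the distributed process group and return (group, rank, world_size).

    Uses NCCL when CUDA is available and gloo otherwise (ccl under ipex);
    single-process runs and DEBUG=1 get a FakeGroup instead of a real group.
    """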
    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
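        # Map this process's global rank onto a local GPU index.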
        device = RANK % torch.cuda.device_count()
        torch.cuda.set_device(device)
        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
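        # NCCL process-group options: high-priority streams and a 120 s collective timeout.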
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=120)
    else:
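        # No CUDA available: fall back to the gloo (CPU) backend with default options.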
        backend = "gloo"
        options = None

    # Single process: skip real initialization and hand back a FakeGroup.
    if WORLD_SIZE == 1:
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
    else:
        # DEBUG=1 forces the fake group even when several processes are declared.
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

        if not torch.distributed.is_initialized():
            # Initialize the default process group.
            if SYSTEM == "ipex":
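                # Intel (ipex) path: initialize the group through ipex's ccl backend.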
                import intel_extension_for_pytorch as ipex

                ipex.distributed.init_process_group(
                    backend="ccl",
                    world_size=WORLD_SIZE,
                    rank=RANK,
                    timeout=timedelta(seconds=120),
                    pg_options=options,
                )
            else:
                torch.distributed.init_process_group(
                    backend=backend,
                    world_size=WORLD_SIZE,
                    rank=RANK,
                    timeout=timedelta(seconds=120),
                    pg_options=options,
                )
        else:
            logger.warning("torch.distributed is already initialized.")

        return torch.distributed.group.WORLD, RANK, WORLD_SIZE