distributed_utils.py
import os
import pickle
from datetime import timedelta
from typing import Any, Optional

import torch
import torch.distributed as dist
from loguru import logger


class DistributedManager:
    """Manage torch.distributed initialization and rank-0 task broadcasting."""

    # 1 MiB chunk size for streaming pickled task payloads over the task group.
    CHUNK_SIZE = 1024 * 1024

    def __init__(self):
        self.is_initialized = False
        self.rank = 0
        self.world_size = 1
        self.device = "cpu"
        self.task_pg = None  # dedicated gloo group for task-distribution messages

    def init_process_group(self) -> bool:
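        """Set up torch.distributed from environment variables.

        Returns True on success, False on failure.
        """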
        try:
            # NOTE: LOCAL_RANK is used as the rank, which assumes a single-node launch.
            self.rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = int(os.environ.get("WORLD_SIZE", 1))

            if self.world_size > 1:
                backend = "nccl" if torch.cuda.is_available() else "gloo"
                dist.init_process_group(backend=backend, init_method="env://")
                logger.info(f"Initialized default process group with backend: {backend}")

                task_timeout = timedelta(days=30)
                self.task_pg = dist.new_group(backend="gloo", timeout=task_timeout)
                logger.info("Created gloo process group for task distribution with 30-day timeout")

                if torch.cuda.is_available():
                    torch.cuda.set_device(self.rank)
                    self.device = f"cuda:{self.rank}"
                else:
                    self.device = "cpu"
            else:
                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

            self.is_initialized = True
            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
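        """Destroy the default process group and reset local state."""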
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
                logger.info(f"Rank {self.rank} distributed environment cleaned up")
        except Exception as e:
            logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
        finally:
            self.is_initialized = False
            self.task_pg = None

    def barrier(self):
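        """Block until all ranks arrive; no-op if the default group was never created."""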
        if self.is_initialized and dist.is_initialized():
            if torch.cuda.is_available() and dist.get_backend() == "nccl":
                dist.barrier(device_ids=[torch.cuda.current_device()])
            else:
                dist.barrier()

    def is_rank_zero(self) -> bool:
        return self.rank == 0

    def _broadcast_byte_chunks(self, data_bytes: bytes) -> None:
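        """Broadcast `data_bytes` from rank 0 in CHUNK_SIZE pieces over the task group."""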
        total_length = len(data_bytes)
        num_full_chunks = total_length // self.CHUNK_SIZE
        remaining = total_length % self.CHUNK_SIZE

        for i in range(num_full_chunks):
            start_idx = i * self.CHUNK_SIZE
            end_idx = start_idx + self.CHUNK_SIZE
            chunk = data_bytes[start_idx:end_idx]
            # Pack the chunk into a uint8 CPU tensor for the gloo broadcast.
            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
            dist.broadcast(task_tensor, src=0, group=self.task_pg)

        if remaining:
            chunk = data_bytes[-remaining:]
            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
            dist.broadcast(task_tensor, src=0, group=self.task_pg)

    def _receive_byte_chunks(self, total_length: int) -> bytes:
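        """Receive `total_length` bytes broadcast by rank 0, chunk by chunk."""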
        if total_length <= 0:
            return b""

        received = bytearray()
        remaining = total_length

        while remaining > 0:
            chunk_length = min(self.CHUNK_SIZE, remaining)
            task_tensor = torch.empty(chunk_length, dtype=torch.uint8)
            dist.broadcast(task_tensor, src=0, group=self.task_pg)
            received.extend(task_tensor.numpy())
            remaining -= chunk_length

        return bytes(received)

    def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
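        """Distribute a pickled task object from rank 0 to all ranks.

        Rank 0 passes the task to broadcast (or None to signal stop); other
        ranks call with no argument and receive the task, or None on stop.
        """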
        if not self.is_initialized:
            return None

        if self.world_size == 1:
            # Single-process mode: there are no peers to broadcast to.
            return task_data

        if self.is_rank_zero():
            if task_data is None:
                stop_signal = torch.tensor([1], dtype=torch.int32)  # 1 signals "stop"
            else:
                stop_signal = torch.tensor([0], dtype=torch.int32)  # 0 signals "task follows"

            dist.broadcast(stop_signal, src=0, group=self.task_pg)

            if task_data is not None:
                task_bytes = pickle.dumps(task_data)
                task_length = torch.tensor([len(task_bytes)], dtype=torch.int64)  # int64: payloads may exceed 2 GiB

                dist.broadcast(task_length, src=0, group=self.task_pg)
                self._broadcast_byte_chunks(task_bytes)

                return task_data
            else:
                return None
        else:
            stop_signal = torch.tensor([0], dtype=torch.int32)  # overwritten in place with rank 0's value
            dist.broadcast(stop_signal, src=0, group=self.task_pg)

            if stop_signal.item() == 1:
                return None
            else:
                task_length = torch.tensor([0], dtype=torch.int64)  # must match the sender's int64 length header

                dist.broadcast(task_length, src=0, group=self.task_pg)
                total_length = int(task_length.item())

                task_bytes = self._receive_byte_chunks(total_length)
                task_data = pickle.loads(task_bytes)
                return task_data
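

# Usage sketch (illustrative, not part of the original module). Assumes the
# script is launched with `torchrun --nproc_per_node=<N> distributed_utils.py`;
# the sample task dicts below are hypothetical placeholders.
if __name__ == "__main__":
    manager = DistributedManager()
    if manager.init_process_group():
        try:
            if manager.is_rank_zero():
                # Rank 0 produces tasks, then broadcasts None as the stop signal.
                for task in ({"task_id": 0}, {"task_id": 1}):
                    manager.broadcast_task_data(task)
                manager.broadcast_task_data(None)
            else:
                # Worker ranks block on each broadcast until the stop signal arrives.
                while (task := manager.broadcast_task_data()) is not None:
                    logger.info(f"Rank {manager.rank} received task: {task}")
            manager.barrier()
        finally:
            manager.cleanup()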