import os
import pickle
from typing import Any, Optional

import torch
import torch.distributed as dist
from loguru import logger

from .gpu_manager import gpu_manager


class DistributedManager:
    # Size (in bytes) of each chunk used when streaming pickled task data over broadcast.
    CHUNK_SIZE = 1024 * 1024

    def __init__(self):
        self.is_initialized = False
        self.rank = 0
        self.world_size = 1
        self.device = "cpu"

    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
        """Initialize the torch.distributed process group for this rank. Returns True on success."""
        try:
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = master_addr
            os.environ["MASTER_PORT"] = master_port

            backend = "nccl" if torch.cuda.is_available() else "gloo"

            dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
            logger.info(f"Setup backend: {backend}")

            # Ask the shared gpu_manager which device this rank should use and remember it.
            self.device = gpu_manager.set_device_for_rank(rank, world_size)

            self.is_initialized = True
            self.rank = rank
            self.world_size = world_size

            logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
                logger.info(f"Rank {self.rank} distributed environment cleaned up")
        except Exception as e:
            logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
        finally:
            self.is_initialized = False

    def barrier(self):
        if self.is_initialized:
            # Pin the NCCL barrier to this rank's CUDA device to avoid device mismatches.
            if torch.cuda.is_available() and dist.get_backend() == "nccl":
                dist.barrier(device_ids=[torch.cuda.current_device()])
            else:
                dist.barrier()

    def is_rank_zero(self) -> bool:
        return self.rank == 0

    def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
        """Broadcast a byte string from rank 0 in CHUNK_SIZE-sized pieces."""
        total_length = len(data_bytes)
        num_full_chunks = total_length // self.CHUNK_SIZE
        remaining = total_length % self.CHUNK_SIZE

        for i in range(num_full_chunks):
            start_idx = i * self.CHUNK_SIZE
            end_idx = start_idx + self.CHUNK_SIZE
            chunk = data_bytes[start_idx:end_idx]
            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
            dist.broadcast(task_tensor, src=0)

        if remaining:
            chunk = data_bytes[-remaining:]
            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
            dist.broadcast(task_tensor, src=0)

    def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
        """Receive total_length bytes broadcast from rank 0 in CHUNK_SIZE-sized pieces."""
        if total_length <= 0:
            return b""

        received = bytearray()
        remaining = total_length

        while remaining > 0:
            chunk_length = min(self.CHUNK_SIZE, remaining)
            task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
            dist.broadcast(task_tensor, src=0)
            received.extend(task_tensor.cpu().numpy())
            remaining -= chunk_length

        return bytes(received)
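
    # Worked example (illustrative): with CHUNK_SIZE = 1 MiB, a 2.5 MiB pickled
    # payload is sent as two full 1 MiB chunks plus one 0.5 MiB remainder, and
    # the receiver reassembles exactly 2.5 MiB before unpickling.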

    def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
        """Broadcast a picklable task object from rank 0 to every rank.

        Rank 0 passes the task object (or None to signal a stop); the other ranks
        call with no argument and receive the deserialized object, or None when
        the stop signal was sent.
        """
        if not self.is_initialized:
            return None

        # Use CPU tensors with the gloo backend; NCCL requires the tensors to live
        # on this rank's CUDA device.
        try:
            backend = dist.get_backend() if dist.is_initialized() else "gloo"
        except Exception:
            backend = "gloo"

        if backend == "gloo":
            broadcast_device = torch.device("cpu")
        else:
            broadcast_device = torch.device(self.device)

        if self.is_rank_zero():
            # stop_signal == 1 tells the other ranks that no task follows.
            if task_data is None:
                stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
            else:
                stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)

            dist.broadcast(stop_signal, src=0)

            if task_data is not None:
                # Serialize the task, announce its length, then stream it in chunks.
                task_bytes = pickle.dumps(task_data)
                task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)

                dist.broadcast(task_length, src=0)
                self._broadcast_byte_chunks(task_bytes, broadcast_device)

                return task_data
            else:
                return None
        else:
            # Non-zero ranks receive the stop signal first, then (if a task follows)
            # its length and the chunked payload.
            stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
            dist.broadcast(stop_signal, src=0)

            if stop_signal.item() == 1:
                return None
            else:
                task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
                dist.broadcast(task_length, src=0)
                total_length = int(task_length.item())

                task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
                task_data = pickle.loads(task_bytes)
                return task_data
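
# Usage sketch (illustrative, not part of the original module). It assumes the
# caller supplies rank/world_size and a reachable master address, and that a
# hypothetical process_task() handler exists on the worker side:
#
#     manager = DistributedManager()
#     manager.init_process_group(rank, world_size, "127.0.0.1", "29500")
#     if manager.is_rank_zero():
#         manager.broadcast_task_data({"prompt": "hello"})  # send one task
#         manager.broadcast_task_data(None)                 # send the stop signal
#     else:
#         while (task := manager.broadcast_task_data()) is not None:
#             process_task(task)                            # hypothetical handler
#     manager.cleanup()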


class DistributedWorker:
    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.dist_manager = DistributedManager()

    def init(self) -> bool:
        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)

    def cleanup(self):
        self.dist_manager.cleanup()

    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
        self.dist_manager.barrier()

        if self.dist_manager.is_rank_zero():
            result = {"task_id": task_id, "status": status, **kwargs}
            result_queue.put(result)
            logger.info(f"Task {task_id} {status}")


def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
    return DistributedWorker(rank, world_size, master_addr, master_port)
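
# End-to-end launch sketch (illustrative, not part of the original module). It
# assumes a single machine, two processes, port "29500", and uses only names
# defined in this file plus torch.multiprocessing:
#
#     import torch.multiprocessing as mp
#
#     def _run(rank, world_size, result_queue):
#         worker = create_distributed_worker(rank, world_size, "127.0.0.1", "29500")
#         if not worker.init():
#             return
#         try:
#             task = {"task_id": "demo"} if worker.dist_manager.is_rank_zero() else None
#             task = worker.dist_manager.broadcast_task_data(task)
#             if task is not None:
#                 worker.sync_and_report(task["task_id"], "completed", result_queue)
#         finally:
#             worker.cleanup()
#
#     if __name__ == "__main__":
#         queue = mp.get_context("spawn").Queue()
#         mp.spawn(_run, args=(2, queue), nprocs=2, join=True)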