# distributed_utils.py
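"""Distributed helpers: process-group setup, barriers, and broadcasting of
pickled task data from rank 0 to all other ranks."""
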
import os
import pickle
from typing import Any, Optional

import torch
import torch.distributed as dist
from loguru import logger

from .gpu_manager import gpu_manager


class DistributedManager:
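    """Thin wrapper around torch.distributed that initializes the process
    group, synchronizes ranks, and broadcasts task payloads from rank 0."""
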
    def __init__(self):
        self.is_initialized = False
        self.rank = 0
        self.world_size = 1
        self.device = "cpu"

    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
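        """Initialize the torch.distributed process group for this rank.

        Uses nccl when CUDA is available (gloo otherwise), binds the rank to a
        device via gpu_manager, and returns True on success or False if
        initialization failed.
        """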
        try:
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = master_addr
            os.environ["MASTER_PORT"] = master_port

            backend = "nccl" if torch.cuda.is_available() else "gloo"

            dist.init_process_group(
                backend=backend,
                init_method=f"tcp://{master_addr}:{master_port}",
                rank=rank,
                world_size=world_size,
            )
            logger.info(f"Setup backend: {backend}")

            self.device = gpu_manager.set_device_for_rank(rank, world_size)

            self.is_initialized = True
            self.rank = rank
            self.world_size = world_size

            logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
                logger.info(f"Rank {self.rank} distributed environment cleaned up")
        except Exception as e:
            logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
        finally:
            self.is_initialized = False

    def barrier(self):
        if self.is_initialized:
            # For NCCL, pass the local CUDA device so the barrier runs on the right GPU.
            if torch.cuda.is_available() and dist.get_backend() == "nccl":
                dist.barrier(device_ids=[torch.cuda.current_device()])
            else:
                dist.barrier()

    def is_rank_zero(self) -> bool:
        return self.rank == 0

    def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
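        """Broadcast a picklable task object from rank 0 to every rank.

        Rank 0 first broadcasts a stop signal (1 means "no task"), then the
        pickled payload length, then the payload itself in chunks of at most
        1 MiB. Returns the task data on every rank, or None when there is no
        task or the process group is not initialized.
        """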
        if not self.is_initialized:
            return None

        try:
            backend = dist.get_backend() if dist.is_initialized() else "gloo"
        except Exception:
            backend = "gloo"

        if backend == "gloo":
            broadcast_device = torch.device("cpu")
        else:
            broadcast_device = torch.device(self.device if self.device != "cpu" else "cpu")

        if self.is_rank_zero():
            if task_data is None:
                stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
            else:
                stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)

            dist.broadcast(stop_signal, src=0)

            if task_data is not None:
                task_bytes = pickle.dumps(task_data)
                task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)

                dist.broadcast(task_length, src=0)

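                # Broadcast the payload, splitting it into 1 MiB chunks when it is large.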
                chunk_size = 1024 * 1024
                if len(task_bytes) > chunk_size:
                    num_chunks = (len(task_bytes) + chunk_size - 1) // chunk_size
                    for i in range(num_chunks):
                        start_idx = i * chunk_size
                        end_idx = min((i + 1) * chunk_size, len(task_bytes))
                        chunk = task_bytes[start_idx:end_idx]
                        task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(broadcast_device)
                        dist.broadcast(task_tensor, src=0)
                else:
                    task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8).to(broadcast_device)
                    dist.broadcast(task_tensor, src=0)

                return task_data
            else:
                return None
        else:
            stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
            dist.broadcast(stop_signal, src=0)

            if stop_signal.item() == 1:
                return None
            else:
                task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
                dist.broadcast(task_length, src=0)

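                # Receive the payload, reassembling 1 MiB chunks when necessary.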
                total_length = int(task_length.item())
                chunk_size = 1024 * 1024

                if total_length > chunk_size:
                    task_bytes = bytearray()
                    num_chunks = (total_length + chunk_size - 1) // chunk_size
                    for i in range(num_chunks):
                        chunk_length = min(chunk_size, total_length - len(task_bytes))
                        task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(broadcast_device)
                        dist.broadcast(task_tensor, src=0)
                        task_bytes.extend(task_tensor.cpu().numpy())
                    task_bytes = bytes(task_bytes)
                else:
                    task_tensor = torch.empty(total_length, dtype=torch.uint8).to(broadcast_device)
                    dist.broadcast(task_tensor, src=0)
                    task_bytes = bytes(task_tensor.cpu().numpy())

                task_data = pickle.loads(task_bytes)
                return task_data


class DistributedWorker:
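    """Per-process helper that owns a DistributedManager and, on rank 0,
    reports task results through a result queue."""
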
    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.dist_manager = DistributedManager()

    def init(self) -> bool:
        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)

    def cleanup(self):
        self.dist_manager.cleanup()

    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
        self.dist_manager.barrier()

        if self.dist_manager.is_rank_zero():
            result = {"task_id": task_id, "status": status, **kwargs}
            result_queue.put(result)
            logger.info(f"Task {task_id} {status}")


def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
    return DistributedWorker(rank, world_size, master_addr, master_port)
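

# Usage sketch: a minimal, hypothetical example of how this module might be
# driven. `_demo_worker`, the demo task payload, and the fixed address/port
# below are illustrative assumptions, not part of the module's API; a real
# service would supply its own task source and result queue. On a CUDA host
# the process group uses nccl, otherwise gloo.
def _demo_worker(rank: int, world_size: int, result_queue) -> None:
    worker = create_distributed_worker(rank, world_size, "127.0.0.1", "29500")
    if not worker.init():
        return
    try:
        # Only rank 0 supplies the task; the other ranks receive it via broadcast.
        task = {"task_id": "demo", "prompt": "hello"} if rank == 0 else None
        task = worker.dist_manager.broadcast_task_data(task)
        if task is not None:
            worker.sync_and_report(task["task_id"], "completed", result_queue)
    finally:
        worker.cleanup()


if __name__ == "__main__":
    import multiprocessing

    import torch.multiprocessing as mp

    demo_world_size = 2
    queue = multiprocessing.Manager().Queue()
    mp.spawn(_demo_worker, args=(demo_world_size, queue), nprocs=demo_world_size, join=True)
    while not queue.empty():
        print(queue.get())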