distributed_utils.py

import os
import pickle

import torch
import torch.distributed as dist
from loguru import logger


class DistributedManager:
    """Thin wrapper around torch.distributed setup, teardown, and rank-0 task broadcast."""

    def __init__(self):
        self.is_initialized = False
        self.rank = 0
        self.world_size = 1

    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
        try:
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = master_addr
            os.environ["MASTER_PORT"] = master_port

            dist.init_process_group(backend="nccl", init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)

            if torch.cuda.is_available():  # type: ignore
                torch.cuda.set_device(rank)

            self.is_initialized = True
            self.rank = rank
            self.world_size = world_size

            logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
                logger.info(f"Rank {self.rank} distributed environment cleaned up")
        except Exception as e:
            logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
        finally:
            self.is_initialized = False

    def barrier(self):
        if self.is_initialized:
            dist.barrier()

    def is_rank_zero(self) -> bool:
        return self.rank == 0

    def broadcast_task_data(self, task_data=None):
        """Broadcast a picklable task object from rank 0 to all other ranks.

        Protocol: rank 0 first broadcasts a one-element stop signal
        (1 = stop, 0 = a task follows), then the payload length,
        then the pickled payload as a uint8 tensor.
        """
        if not self.is_initialized:
            return None

        device = f"cuda:{self.rank}"

        if self.is_rank_zero():
            # Signal whether a task follows (0) or the workers should stop (1).
            stop_signal = torch.tensor([1 if task_data is None else 0], dtype=torch.int32, device=device)
            dist.broadcast(stop_signal, src=0)

            if task_data is None:
                return None

            # Serialize the task, then broadcast its length followed by its bytes.
            task_bytes = pickle.dumps(task_data)
            task_length = torch.tensor([len(task_bytes)], dtype=torch.int32, device=device)
            dist.broadcast(task_length, src=0)

            task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8, device=device)
            dist.broadcast(task_tensor, src=0)

            return task_data
        else:
            # Receive the stop signal; a value of 1 means no task is coming.
            stop_signal = torch.tensor([0], dtype=torch.int32, device=device)
            dist.broadcast(stop_signal, src=0)

            if stop_signal.item() == 1:
                return None

            # Receive the payload length, then the pickled task bytes.
            task_length = torch.tensor([0], dtype=torch.int32, device=device)
            dist.broadcast(task_length, src=0)

            task_tensor = torch.empty(int(task_length.item()), dtype=torch.uint8, device=device)
            dist.broadcast(task_tensor, src=0)

            task_bytes = bytes(task_tensor.cpu().numpy())
            return pickle.loads(task_bytes)

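# Illustrative sketch (an assumption, not part of the original API): one way the
# broadcast protocol above can drive a dispatch loop. `tasks` and the processing
# steps are hypothetical placeholders; rank 0 pushes each task to every rank and
# finally broadcasts None so the worker ranks leave their receive loop.
def _example_task_loop(manager: DistributedManager, tasks=()):
    if manager.is_rank_zero():
        for task in tasks:
            manager.broadcast_task_data(task)  # every rank now holds `task`
            # ... process `task` on rank 0 here ...
        manager.broadcast_task_data(None)  # stop signal for the other ranks
    else:
        while True:
            task = manager.broadcast_task_data()
            if task is None:  # rank 0 broadcast the stop signal
                break
            # ... process `task` on this rank here ...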

class DistributedWorker:
    """Per-process helper that owns a DistributedManager and reports results from rank 0."""

    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.dist_manager = DistributedManager()

    def init(self) -> bool:
        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)

    def cleanup(self):
        self.dist_manager.cleanup()

    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
        # Synchronize all processes
        self.dist_manager.barrier()

        if self.dist_manager.is_rank_zero():
            result = {"task_id": task_id, "status": status, **kwargs}
            result_queue.put(result)
            logger.info(f"Task {task_id} {status}")


def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
    return DistributedWorker(rank, world_size, master_addr, master_port)
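

# Illustrative launch sketch (an assumption, not part of the original API): spawns
# one process per local GPU with torch.multiprocessing and runs a single demo task
# through the broadcast/report path. The address, port, and "demo" payload are
# placeholders; the NCCL backend requires at least one CUDA device per process.
def _example_worker_main(rank: int, world_size: int, master_addr: str, master_port: str, result_queue):
    worker = create_distributed_worker(rank, world_size, master_addr, master_port)
    if not worker.init():
        return
    try:
        # Rank 0 supplies the task; every other rank receives it via the broadcast protocol.
        task = worker.dist_manager.broadcast_task_data({"task_id": "demo"} if rank == 0 else None)
        if task is not None:
            worker.sync_and_report(task["task_id"], "completed", result_queue)
    finally:
        worker.cleanup()


if __name__ == "__main__":
    import torch.multiprocessing as mp

    world_size = max(torch.cuda.device_count(), 1)
    result_queue = mp.get_context("spawn").Queue()
    mp.spawn(
        _example_worker_main,
        args=(world_size, "127.0.0.1", "29500", result_queue),
        nprocs=world_size,
        join=True,
    )
    while not result_queue.empty():
        logger.info(f"Collected result: {result_queue.get()}")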