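"""Distributed utilities for torchrun-launched multi-process pipelines.

DistributedManager wraps process-group setup and teardown and provides a
pickle-based broadcast of arbitrary task objects from rank 0 to all workers.
"""
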
import os
import pickle
from typing import Any, Optional

import torch
import torch.distributed as dist
from loguru import logger


class DistributedManager:
    # Pickled payloads are broadcast in fixed 1 MiB chunks
    CHUNK_SIZE = 1024 * 1024

    def __init__(self):
        self.is_initialized = False
        self.rank = 0
        self.local_rank = 0
        self.world_size = 1
        self.device = "cpu"

    def init_process_group(self) -> bool:
        """Initialize process group using torchrun environment variables"""
        try:
            # torchrun sets these environment variables automatically;
            # RANK is the global rank, LOCAL_RANK indexes GPUs on this node
            self.rank = int(os.environ.get("RANK", 0))
            self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = int(os.environ.get("WORLD_SIZE", 1))

            if self.world_size > 1:
                # torchrun also exports MASTER_ADDR/MASTER_PORT, so
                # init_method="env://" needs no further parameters
                backend = "nccl" if torch.cuda.is_available() else "gloo"
                dist.init_process_group(backend=backend, init_method="env://")
                logger.info(f"Setup backend: {backend}")

                # Bind this process to its local GPU (one GPU per local rank)
                if torch.cuda.is_available():
                    torch.cuda.set_device(self.local_rank)
                    self.device = f"cuda:{self.local_rank}"
                else:
                    self.device = "cpu"
            else:
                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

            self.is_initialized = True
            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
                logger.info(f"Rank {self.rank} distributed environment cleaned up")
        except Exception as e:
            logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
        finally:
            self.is_initialized = False

    def barrier(self):
        # No-op unless a multi-process group is actually running
        if self.is_initialized and dist.is_initialized():
            if torch.cuda.is_available() and dist.get_backend() == "nccl":
                # NCCL barriers need the participating device pinned explicitly
                dist.barrier(device_ids=[torch.cuda.current_device()])
            else:
                dist.barrier()

    def is_rank_zero(self) -> bool:
        return self.rank == 0

    def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
        """Broadcast a byte payload from rank 0 in CHUNK_SIZE pieces."""
        for start_idx in range(0, len(data_bytes), self.CHUNK_SIZE):
            chunk = data_bytes[start_idx : start_idx + self.CHUNK_SIZE]
            # bytearray gives a writable buffer, so frombuffer avoids the
            # per-byte copy of torch.tensor(list(chunk))
            chunk_tensor = torch.frombuffer(bytearray(chunk), dtype=torch.uint8).to(device)
            dist.broadcast(chunk_tensor, src=0)

    def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
        """Receive a byte payload broadcast from rank 0 in CHUNK_SIZE pieces."""
        if total_length <= 0:
            return b""

        received = bytearray()
        remaining = total_length

        while remaining > 0:
            # Chunk sizes mirror the sender: full chunks plus a final remainder
            chunk_length = min(self.CHUNK_SIZE, remaining)
            chunk_tensor = torch.empty(chunk_length, dtype=torch.uint8, device=device)
            dist.broadcast(chunk_tensor, src=0)
            received.extend(chunk_tensor.cpu().numpy().tobytes())
            remaining -= chunk_length

        return bytes(received)
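
    # Wire protocol used by broadcast_task_data (summary of the code below):
    #   1. rank 0 broadcasts a 1-element int32 stop signal (1 = no more tasks)
    #   2. if a task follows, rank 0 broadcasts the pickled payload length
    #   3. the payload itself follows as CHUNK_SIZE-sized uint8 tensors
    # All ranks must issue the same sequence of collectives in the same order.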

    def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
        """Broadcast pickled task data from rank 0 to all ranks.

        Rank 0 passes the payload in (None signals "no more tasks"); other
        ranks call with no argument and get the payload or None back.
        """
        if not self.is_initialized:
            return None
        # Single-process runs have no process group; hand the task straight back
        if not dist.is_initialized():
            return task_data

        # gloo broadcasts CPU tensors; nccl requires tensors on this rank's GPU
        try:
            backend = dist.get_backend() if dist.is_initialized() else "gloo"
        except Exception:
            backend = "gloo"

        broadcast_device = torch.device("cpu") if backend == "gloo" else torch.device(self.device)

        if self.is_rank_zero():
            # A one-element stop signal tells workers whether a payload follows
            if task_data is None:
                stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
            else:
                stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)

            dist.broadcast(stop_signal, src=0)

            if task_data is not None:
                task_bytes = pickle.dumps(task_data)
                # int64 length so payloads over 2 GiB cannot overflow
                task_length = torch.tensor([len(task_bytes)], dtype=torch.int64).to(broadcast_device)
                dist.broadcast(task_length, src=0)
                self._broadcast_byte_chunks(task_bytes, broadcast_device)
                return task_data
            else:
                return None
        else:
            # Workers mirror rank 0: stop signal first, then length, then chunks
            stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
            dist.broadcast(stop_signal, src=0)

            if stop_signal.item() == 1:
                return None
            else:
                task_length = torch.tensor([0], dtype=torch.int64).to(broadcast_device)
                dist.broadcast(task_length, src=0)
                total_length = int(task_length.item())

                task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
                task_data = pickle.loads(task_bytes)
                return task_data
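

# A minimal usage sketch (not part of the original module): rank 0 streams a
# few illustrative task dicts to the workers, then broadcasts None as the stop
# signal. Launch with e.g. `torchrun --nproc_per_node=2 distributed_utils.py`;
# the payload contents below are placeholders.
if __name__ == "__main__":
    manager = DistributedManager()
    if manager.init_process_group():
        try:
            if manager.is_rank_zero():
                for step in range(3):
                    # Hypothetical payload; any picklable object works
                    manager.broadcast_task_data({"step": step})
                manager.broadcast_task_data(None)  # tell workers to stop
            else:
                while True:
                    task = manager.broadcast_task_data()
                    if task is None:  # stop signal received
                        break
                    logger.info(f"Rank {manager.rank} got task: {task}")
            manager.barrier()
        finally:
            manager.cleanup()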