from typing import List, Optional

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger


class PyTorchProcessGroupDict(metaclass=SingletonMeta):

    def __init__(self):
        # distributed settings
        # use this dict to record all Pytorch ProcessGroups
        self.dict = {}
        # set a distributed logger
        self.logger = get_dist_logger('ProcessGroup')

    def log_pg_init(self, rank_list: List[int], backend: str):
        str_list = ["Pytorch ProcessGroup Init:"]
        str_list.append(f"backend: {backend}")
        str_list.append(f"ranks: {rank_list}")
        self.logger.info("\n\t".join(str_list), ranks=[0])

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        """Reuse Pytorch ProcessGroup when such a group is initialized
        """
        # we need to convert the passed list to a tuple
        # since List is unhashable
        processgroup_key = (backend, tuple(rank_list))
        if processgroup_key not in self.dict:
            self.log_pg_init(rank_list=rank_list, backend=backend)
            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
        return self.dict[processgroup_key]


PYTORCHPGDICT_ = PyTorchProcessGroupDict()
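
# A minimal usage sketch (for illustration only, and assuming torch.distributed has
# already been initialized): repeated calls with the same ranks and backend return
# the cached group rather than creating a new one.
#
#   tp_pg_a = PYTORCHPGDICT_.get([0, 1], 'nccl')
#   tp_pg_b = PYTORCHPGDICT_.get([0, 1], 'nccl')
#   assert tp_pg_a is tp_pg_b    # reused via the (backend, tuple(ranks)) cache key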


class ProcessGroup:
    """ProcessGroup
    Process Group indicates how processes are organized into groups for parallel execution using Tensor Parallelism and Data Parallelism.

    NOTE, the ProcessGroup must be used after `torch.distributed.init_process_group()` has been called.

    Args:
        rank: the global rank of the current process.
        ranks: List[int], a list of rank ids belonging to this process group.
        backend: str, the backend of the process group.
        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. Default None means 1.
        dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. Default None means len(ranks).
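
    Example (a minimal usage sketch, assuming `torch.distributed` has already been
    initialized across 4 processes)::

        >>> pg = ProcessGroup(tp_degree=2, dp_degree=2)    # 2-way TP x 2-way DP
        >>> pg.tp_local_rank()    # this rank's position inside its TP group
        >>> pg.dp_local_rank()    # this rank's position inside its DP group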
    """

    def __init__(self,
                 rank: Optional[int] = None,
                 ranks: Optional[List[int]] = None,
                 tp_degree: Optional[int] = None,
                 dp_degree: Optional[int] = None) -> None:
        if not torch.distributed.is_initialized():
            self.is_init = False
            return

        assert torch.distributed.is_initialized(), "ProcessGroup must be used after torch.distributed is initialized"

        self._rank = torch.distributed.get_rank()
        if rank is not None:
            assert self._rank == rank    # make sure that the global rank is correct

        if ranks is None:
            self._rank_list = list(range(torch.distributed.get_world_size()))
        else:
            self._rank_list = ranks
            self._rank_list.sort()    # ensure that the list is in order

        self._world_size = len(self._rank_list)

        if dp_degree is None and tp_degree is None:
            self._dp_degree = self._world_size
            self._tp_degree = 1
        elif dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, f"the world size {self._world_size} should be divisible by the DP degree {dp_degree} when the TP degree is None"
            self._tp_degree = self._world_size // dp_degree
        elif not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, f"the world size {self._world_size} should be divisible by the TP degree {tp_degree} when the DP degree is None"
            self._dp_degree = self._world_size // tp_degree
        else:
            self._dp_degree = dp_degree
            self._tp_degree = tp_degree
            assert self._dp_degree * self._tp_degree == self._world_size, \
                f"the world size {self._world_size} should equal the product of the DP degree {self._dp_degree} " \
                f"and the TP degree {self._tp_degree}"

        self._tp_rank_list = None
        self._dp_rank_list = None
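
        # A concrete layout sketch (hypothetical values for illustration): with ranks
        # [0, 1, 2, 3], tp_degree=2 and dp_degree=2, the loops below build the TP groups
        # [0, 1] and [2, 3] and the DP groups [0, 2] and [1, 3]; each rank records the
        # single TP list and single DP list it belongs to.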
        for i in range(self._dp_degree):
            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
            PYTORCHPGDICT_.get(i_tp_list, 'nccl')
            if self._rank in i_tp_list:
                self._tp_rank_list = i_tp_list

        for j in range(self._tp_degree):
            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
            PYTORCHPGDICT_.get(j_dp_list, 'nccl')
            if self._rank in j_dp_list:
                self._dp_rank_list = j_dp_list

        self._has_cpu_groups = False
        self.is_init = True

    def set_cpu_groups(self):
        """set_cpu_groups

        Initialize PyTorch process groups for CPU communications.
        """
        if self.has_cpu_groups:
            return

        for i in range(self._dp_degree):
            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
            PYTORCHPGDICT_.get(i_tp_list, 'gloo')

        for j in range(self._tp_degree):
            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
            PYTORCHPGDICT_.get(j_dp_list, 'gloo')

        self._has_cpu_groups = True
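
    # Note: callers that need gloo-backed (CPU) communication are expected to call
    # `set_cpu_groups()` once before using `cpu_dp_process_group()` or
    # `cpu_tp_process_group()`, both of which assert that the CPU groups exist.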

    @property
    def has_cpu_groups(self) -> bool:
        """has_cpu_groups
133
134
135
136
137
        If cpu groups have been initailized.

        Returns:
            bool: cpu process groups have been initialized or not.
        """
        return self._has_cpu_groups

    def __repr__(self):
        if self.is_init:
            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
            personal_str = f"             rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
            return ranks_str + personal_str
        else:
            return "ProcessGroup not initialized"

    def __eq__(self, obj: 'ProcessGroup') -> bool:
        if not isinstance(obj, ProcessGroup):
            return False
        if self._rank != obj._rank:
            return False
        if self._rank_list != obj._rank_list:
            return False
        if self._tp_rank_list != obj._tp_rank_list:
            return False
        if self._dp_rank_list != obj._dp_rank_list:
            return False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

    def rank(self) -> int:
        """rank

        The current rank in the global process group.

        Returns:
            int: the rank number
        """
        return self._rank

    def ranks_in_group(self) -> List[int]:
        """ranks_in_group

        A list of rank numbers in the global process group.

        Returns:
            List[int]: a list of rank number.
        """
        return self._rank_list

    def world_size(self) -> int:
        """world_size

        The world size of the global process group.

        Returns:
            int: world size
        """
        return self._world_size

    def tp_rank_list(self) -> List[int]:
        """tp_rank_list

        The rank list of the TP process group that contains the current rank.

        Returns:
            List[int]: the list of rank numbers.
        """
        return self._tp_rank_list

    def dp_rank_list(self) -> List[int]:
        """dp_rank_list

        The rank list of the DP process group that contains the current rank.

        Returns:
            List[int]: the list of rank numbers.
        """
        return self._dp_rank_list

    def tp_local_rank(self) -> int:
        """tp_local_rank

        The local rank number in the current TP process group.

        Returns:
            int: tp rank number.
        """
        return self._rank % self._tp_degree

    def dp_local_rank(self) -> int:
        """dp_local_rank

        The local rank number in the current DP process group.

        Returns:
            int: dp rank number.
        """
        return self._rank // self._tp_degree

    def dp_world_size(self) -> int:
        """dp_world_size

        The world size of the current DP process group.

        Returns:
            int: dp world size
        """
        return len(self._dp_rank_list)

    def tp_world_size(self) -> int:
        """tp_world_size

        The world size of the current TP process group.

        Returns:
            int: tp world size
        """
        return len(self._tp_rank_list)

    def dp_process_group(self):
        """dp_process_group

        The PyTorch DP process group that contains the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
        """
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

    def tp_process_group(self):
        """tp_process_group

        The PyTorch TP process group that contains the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
        """
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

    def cpu_dp_process_group(self):
        """cpu_dp_process_group

        The PyTorch CPU DP process group that contains the current rank.

        Raises an AssertionError if the CPU process groups have not been initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

    def cpu_tp_process_group(self):
        """cpu_tp_process_group

        The PyTorch CPU TP process group that contains the current rank.

        Raises an AssertionError if the CPU process groups have not been initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

    def get_ranks_in_dp(self) -> List[int]:
        """get_ranks_in_dp

        Ranks in the current DP process group.

        Returns:
            List[int]: a list of rank number.
        """
        return self._dp_rank_list

    def get_ranks_in_tp(self) -> List[int]:
        """get_ranks_in_tp

        Ranks in the current TP process group.

        Returns:
            List[int]: a list of rank number.
        """
        return self._tp_rank_list