from typing import List, Optional

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger


class PyTorchProcessGroupDict(metaclass=SingletonMeta):

    def __init__(self):
        # (backend, rank_tuple) -> torch.distributed ProcessGroup cache
        self.dict = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        """Reuse Pytorch ProcessGroup when such a group is initialized
        """
        rank_tuple = tuple(rank_list)
        # we need to convert the passed list to a tuple
        # since List is unhashable
        pg_key = (backend, rank_tuple)

        if pg_key not in self.dict:

            self.logger = get_dist_logger('ProcessGroup')
            self.logger.info(f'Initializing a new {backend} ProcessGroup on ranks {rank_list}', ranks=[0])
            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
        return self.dict[pg_key]


PYTORCHPGDICT_ = PyTorchProcessGroupDict()
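# A minimal usage sketch (illustrative only, assuming torch.distributed has already
# been initialized and that ranks 0-3 exist): repeated requests for the same
# (backend, ranks) pair return the cached group instead of creating a new one.
#
#   pg_a = PYTORCHPGDICT_.get([0, 1, 2, 3], 'nccl')
#   pg_b = PYTORCHPGDICT_.get([0, 1, 2, 3], 'nccl')
#   assert pg_a is pg_b    # the same torch.distributed ProcessGroup object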


class ProcessGroup:
    """
    ProcessGroup maintains the group partitions for Tensor Parallelism and Data Parallelism.

    NOTE: a ProcessGroup must be created after torch.distributed.init_process_group() has been called.
    NCCL groups for the TP/DP partitions are created eagerly in __init__; Gloo (CPU) groups
    can be created later via set_cpu_groups().

    Args:
        rank: Optional[int], the global rank of the current process. Defaults to the rank
            reported by torch.distributed.
        ranks: Optional[List[int]], a list of global rank ids belonging to this process group.
            Defaults to all ranks.
        tp_degree: Optional[int], tensor parallelism degree. Defaults to
            len(ranks) // dp_degree if dp_degree is given, otherwise 1.
        dp_degree: Optional[int], data parallelism degree. Defaults to
            len(ranks) // tp_degree if tp_degree is given, otherwise len(ranks).
    """

    def __init__(self,
                 rank: Optional[int] = None,
                 ranks: Optional[List[int]] = None,
                 tp_degree: Optional[int] = None,
                 dp_degree: Optional[int] = None) -> None:
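        # If torch.distributed has not been initialized yet, build a placeholder group
        # that only records is_init = False; no communication groups are created.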
        if not torch.distributed.is_initialized():
            self.is_init = False
            return

        assert torch.distributed.is_initialized(), "ProcessGroup must be used after torch.distributed is initialized"
        if rank is None:
            self._rank = torch.distributed.get_rank()
        else:
            self._rank = rank

        if ranks is None:
            self._rank_list = list(range(torch.distributed.get_world_size()))
        else:
            self._rank_list = sorted(ranks)    # keep the list in ascending order without mutating the caller's list

        self._world_size = len(self._rank_list)

        if dp_degree is None and tp_degree is None:
            self._dp_degree = self._world_size
            self._tp_degree = 1
        elif dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._tp_degree = self._world_size // dp_degree
        elif not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._dp_degree = self._world_size // tp_degree
        else:
            self._dp_degree = dp_degree
            self._tp_degree = tp_degree
            assert self._dp_degree * self._tp_degree == self._world_size, \
                f"the world size {self._world_size} should equal the product of DP degree {self._dp_degree} " \
                f"and TP degree {self._tp_degree}"

        self._tp_rank_list = None
        self._dp_rank_list = None

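        # The sorted rank list is read as a dp_degree x tp_degree grid laid out row-major:
        # row i is one tensor-parallel group, column j is one data-parallel group.
        # For example (illustrative), with ranks [0..7], dp_degree=4 and tp_degree=2:
        #   TP groups (rows):    [0, 1], [2, 3], [4, 5], [6, 7]
        #   DP groups (columns): [0, 2, 4, 6], [1, 3, 5, 7]
        # Every group is registered in PYTORCHPGDICT_ so that all ranks create the same
        # NCCL groups in the same order, and the current rank remembers its own two groups.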
        for i in range(self._dp_degree):
            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
            PYTORCHPGDICT_.get(i_tp_list, 'nccl')
            if self._rank in i_tp_list:
                self._tp_rank_list = i_tp_list

        for j in range(self._tp_degree):
            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
            PYTORCHPGDICT_.get(j_dp_list, 'nccl')
            if self._rank in j_dp_list:
                self._dp_rank_list = j_dp_list

        self._has_cpu_groups = False
        self.is_init = True

    def set_cpu_groups(self):
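        """Create Gloo (CPU) groups that mirror the NCCL TP/DP grid; calling it again is a no-op."""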
        if self.has_cpu_groups:
            return

        for i in range(self._dp_degree):
            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
            PYTORCHPGDICT_.get(i_tp_list, 'gloo')

        for j in range(self._tp_degree):
            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
            PYTORCHPGDICT_.get(j_dp_list, 'gloo')

        self._has_cpu_groups = True

    @property
    def has_cpu_groups(self):
        return self._has_cpu_groups

    def __repr__(self):
        if self.is_init:
            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
        else:
            return "ProcessGroup not initialized"

    def __eq__(self, obj: 'ProcessGroup') -> bool:
        if not isinstance(obj, ProcessGroup):
            return False
        if self._rank != obj._rank:
            return False
        if self._rank_list != obj._rank_list:
            return False
        if self._tp_rank_list != obj._tp_rank_list:
            return False
        if self._dp_rank_list != obj._dp_rank_list:
            return False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

    def rank(self):
        return self._rank

    def ranks_in_group(self):
        return self._rank_list

    def world_size(self):
        return self._world_size

    def tp_rank_list(self):
        return self._tp_rank_list

    def dp_rank_list(self):
        return self._dp_rank_list

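    # Local coordinates in the dp x tp grid, computed from the global rank. This
    # arithmetic assumes the group covers the contiguous ranks 0..world_size-1
    # (the default when `ranks` is not given): tp_local_rank is the column index,
    # dp_local_rank is the row index of the current process.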
    def tp_local_rank(self):
        return self._rank % self._tp_degree

    def dp_local_rank(self):
        return self._rank // self._tp_degree

    def dp_world_size(self):
        return len(self._dp_rank_list)

    def tp_world_size(self):
        return len(self._tp_rank_list)

    def dp_process_group(self):
        # fetch the cached NCCL group shared with this rank's data-parallel peers
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

    def tp_process_group(self):
        # fetch the cached NCCL group shared with this rank's tensor-parallel peers
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

    def cpu_dp_process_group(self):
        assert self._has_cpu_groups, "call set_cpu_groups() before requesting CPU process groups"
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

    def cpu_tp_process_group(self):
        assert self._has_cpu_groups, "call set_cpu_groups() before requesting CPU process groups"
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

    def get_ranks_in_dp(self):
        return self._dp_rank_list

    def get_ranks_in_tp(self):
        return self._tp_rank_list
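

# A minimal end-to-end sketch (not part of the library API), assuming this module is
# launched with `torchrun --nproc_per_node=8 process_group.py` on a node with 8 GPUs
# so that torch.distributed can build the NCCL groups.
if __name__ == '__main__':
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(torch.distributed.get_rank())

    # Partition the 8 ranks into a 4 (DP) x 2 (TP) grid.
    pg = ProcessGroup(tp_degree=2)
    print(f'rank {pg.rank()}: tp peers {pg.tp_rank_list()}, dp peers {pg.dp_rank_list()}')

    # NCCL groups were created eagerly in __init__; Gloo (CPU) groups are opt-in.
    pg.set_cpu_groups()
    assert pg.cpu_tp_process_group() is not None

    torch.distributed.destroy_process_group()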