"""This code is adapted from Alpa
    https://github.com/alpa-projects/alpa/
   with some changes. """

import operator
from dataclasses import dataclass
from functools import reduce
from typing import Dict, List, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


@dataclass
class ProcessGroupContainer:
    process_group: ProcessGroup
    ranks: List[int]


# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
class DeviceMesh:
    """A logical view of a physical cluster. For example, we could view a physical cluster
    with 16 devices as a device mesh with shape (2, 2, 4) or (4, 4).

    Arguments:
        physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
        mesh_shape (torch.Size, optional): shape of the logical view.
        logical_mesh_id (torch.Tensor, optional): logical view of the devices in global rank.
        mesh_alpha (List[float], optional): coefficients used for computing
            communication cost (default: None)
        mesh_beta (List[float], optional): coefficients used for computing
            communication cost (default: None)
        init_process_group (bool, optional): initialize the logical process groups while constructing
            the DeviceMesh instance if set to True. Otherwise, users need to call
            init_logical_process_group manually to initialize the logical process groups.
            (default: False)
        device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
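
    Example (a minimal sketch):

    ```python
    # view a cluster of 4 devices as a 2 x 2 logical mesh
    physical_mesh_id = torch.arange(4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
    ```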
    """

    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}

    def __init__(
        self,
        physical_mesh_id: torch.Tensor,
        mesh_shape: torch.Size = None,
        logical_mesh_id: torch.Tensor = None,
        mesh_alpha: List[float] = None,
        mesh_beta: List[float] = None,
        init_process_group: bool = False,
        device: str = "cuda",
    ):
        # ============================
        # Physical & Logical Mesh IDs
        # ============================
        self._physical_mesh_id = physical_mesh_id
        assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."

        # logical mesh ids can be obtained in two ways:
        # 1. provide the physical mesh id together with a mesh shape
        # 2. directly supply the logical mesh id
        assert mesh_shape is None or logical_mesh_id is None, (
            "Only one of mesh_shape and logical_mesh_id can be specified. "
            "Logical mesh IDs are obtained either from mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id."
        )

        if logical_mesh_id is None:
            self._mesh_shape = mesh_shape
            self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
        else:
            self._logical_mesh_id = logical_mesh_id
            self._mesh_shape = self._logical_mesh_id.shape

        # ensure two things:
        # 1. logical and physical mesh IDs should contain the same elements
        # 2. there are no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
        assert torch.equal(
            torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
        ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
        assert (
            torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
        ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
        assert (
            torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
        ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

        # ===============================================
        # coefficient for alpha-beta communication model
        # alpha is latency and beta is bandwidth
        # ===============================================
        # if the values are not provided, we assume they are 1 for simplicity
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self._mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self._mesh_shape)

        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)

        # ensure the alpha and beta have the same length
        assert len(self.mesh_alpha) == len(
            self.mesh_beta
        ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

        # =========================
        # Device for Process Group
        # =========================
        self._device = device
        self._dist_backend = self._DIST_BACKEND[device]

        # =========================
        # Process Group Management
        # =========================
        # the _global_to_local_rank_mapping is structured as follows
        # {
        #    <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
        # }
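        # e.g. (illustrative) for a (2, 2) logical mesh [[0, 1], [2, 3]],
        # global rank 3 maps to the local ranks [1, 1]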
        self._global_to_local_rank_mapping = dict()
        self._init_global_to_logical_rank_mapping(
            mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
        )

        # create process group
        self._process_group_dict = {}
        self._ranks_in_the_process_group = {}
        self._global_rank_of_current_process = None
        self._is_initialized = False

        # attribute used to indicate whether this object
        # is created using DeviceMesh.from_process_group
        # this attribute can be used to do some checks in methods
        # such as get_process_group, as no global rank information
        # is known if created with from_process_group
        self._is_init_from_process_group = False

        # initialize process group if specified
        self._init_ranks_in_the_same_group()
        self._init_process_group = init_process_group
        if init_process_group:
            self.init_logical_process_group()

    @property
    def shape(self) -> torch.Size:
        """
        Return the shape of the logical mesh.
        """
        return self._mesh_shape

    @property
    def num_devices(self) -> int:
        """
        Return the number of devices contained in the device mesh.
        """
        return reduce(operator.mul, self._physical_mesh_id.shape, 1)

    @property
    def logical_mesh_id(self) -> torch.Tensor:
        """
        Return the logical mesh id.
        """
        return self._logical_mesh_id

    @property
    def is_initialized(self) -> bool:
        """
        Return whether the process group is initialized.
        """
        return self._is_initialized

    @staticmethod
    def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
        """
        Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
        will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.

        Args:
            process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
                If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
                the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.

        Returns:
            DeviceMesh: the device mesh instance.
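
        Example (a minimal sketch; assumes torch.distributed is already initialized and that
        dp_group and tp_group are process groups created by the user):

        ```python
        device_mesh = DeviceMesh.from_process_group([dp_group, tp_group])
        ```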
        """

        def _get_device_by_backend(process_group):
            """
            Get the device type given a process group's backend.
            """
            backend = dist.get_backend(process_group)
            for _device, _backend in DeviceMesh._DIST_BACKEND.items():
                if _backend == backend:
                    return _device
            return None

        if isinstance(process_group, ProcessGroup):
            process_group = [process_group]

        # get mesh shape
        mesh_shape = [dist.get_world_size(pg) for pg in process_group]

        # get device
        device_list = [_get_device_by_backend(pg) for pg in process_group]

        # make sure all devices are the same
        assert all(
            [device == device_list[0] for device in device_list]
        ), "All devices should be the same, please check your input process groups are created with the same distributed backend."

        # create a fake physical mesh id
        # as we only get the process group associated with the current process,
        # we cannot get the global ranks for all processes in the mesh
        # therefore, we only use this fake physical mesh id to create the device mesh
        # and will remove this fake physical mesh id later
        fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))

        # create the device mesh
        device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])

        # hack the device attribute
        device_mesh._physical_mesh_id = None
        device_mesh._logical_mesh_id = None
        device_mesh._global_rank_of_current_process = dist.get_rank()
        device_mesh._is_initialized = False
        device_mesh._is_init_from_process_group = True
        device_mesh._process_group_dict = {
            device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
        }

        return device_mesh

    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
        """
        Return the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        elif self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported because no global rank information is known."
            )
        return self._process_group_dict[global_rank][axis]

    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
        """
        Return the process groups for all axes.

        Args:
            global_rank (int, optional): the global rank of the process
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        elif self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported because no global rank information is known."
            )
        return self._process_group_dict[global_rank]

    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
        """
        Return the ranks in the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process
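
        Example (illustrative, assuming a 4 x 4 logical mesh over global ranks 0-15):

        ```python
        device_mesh.get_ranks_in_process_group(axis=0, global_rank=0)
        # -> [0, 4, 8, 12]
        ```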
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        elif self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported because no global rank information is known."
            )
        return self._ranks_in_the_process_group[global_rank][axis]

    def __deepcopy__(self, memo) -> "DeviceMesh":
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != "_process_group_dict":
                setattr(result, k, __import__("copy").deepcopy(v, memo))
            else:
                # process groups cannot be copied
                # thus, we share them directly
                setattr(result, k, v)
        return result

    def _init_global_to_logical_rank_mapping(
        self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
    ) -> Dict[int, List[int]]:
        """
        Build a global-rank-to-local-rank mapping for each process group along the different axes of the logical device mesh.

        Args:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
            index_list (List[int]): the local ranks accumulated for the axes traversed so far; used internally during recursion.

        Returns:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                The value is a list of integers and each integer represents the local rank in the indexed axis.
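
        Example (illustrative): for a logical mesh [[0, 1], [2, 3]], the resulting mapping is
        {0: [0, 0], 1: [0, 1], 2: [1, 0], 3: [1, 1]}.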
        """
        for index, inner_tensor in enumerate(tensor):
            # index means the local rank in the current axis
            # inner_tensor refers to the processes with the same local rank

            if inner_tensor.numel() == 1:
                # if the inner_tensor only has one element, it means that
                # it already reaches the last axis
                # we append its local_rank in the last axis to the index_list
                # and assign it to the mapping
                # the value of the mapping is the local rank at the indexed axis of the device mesh
                mapping[int(inner_tensor)] = index_list + [index]
            else:
                # we recursively go into the function until we reach the last axis
                # meanwhile, we should add the local rank in the current axis in the index_list
                self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])

    def init_logical_process_group(self):
        """
        This method is used to initialize the logical process groups which will be used for communication
        within the logical device mesh.

        Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
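
        Example (a minimal sketch; assumes torch.distributed has been initialized with 4 ranks):

        ```python
        device_mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), init_process_group=False)
        device_mesh.init_logical_process_group()
        row_group = device_mesh.get_process_group(axis=0)
        ```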
        """
        # sanity check
        assert (
            dist.is_initialized()
        ), "torch.distributed should be initialized before calling init_logical_process_group"
        assert (
            not self._is_initialized
        ), "The logical process group has been initialized, do not call init_logical_process_group twice"

        # update the global rank of the current process
        self._global_rank_of_current_process = dist.get_rank()
        duplicate_check_list = []

        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # skip duplicated process group creation
                if ranks_in_same_group in duplicate_check_list:
                    continue

                # create the process group
                pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
                # record the ranks so that the same group is not created again for its other members
                duplicate_check_list.append(ranks_in_same_group)

                # keep this process group in the process_groups_dict
                for rank in ranks_in_same_group:
                    if rank not in self._process_group_dict:
                        self._process_group_dict[rank] = dict()
                    self._process_group_dict[rank][axis] = pg_handler

        # update the init flag
        # we only allow init for once
        self._is_initialized = True

    def _init_ranks_in_the_same_group(self):
        """
        This method is used to initialize the ranks_in_the_same_group dictionary.
        """
        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # create dict for each rank
                if global_rank not in self._ranks_in_the_process_group:
                    self._ranks_in_the_process_group[global_rank] = dict()

                # keep the ranks of this process group in the _ranks_in_the_process_group dict
                self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group

    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
        """
        Return the local rank of the given global rank in the logical device mesh.

        Args:
            rank (int): the global rank in the logical device mesh.
            axis (int): the axis of the logical device mesh.
        """
        if self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported because no global rank information is known."
            )

        local_ranks = self._global_to_local_rank_mapping[rank]
        if axis is not None:
            return local_ranks[axis]
        else:
            return local_ranks

    def _collate_global_ranks_in_same_process_group(self, global_rank):
        """
        Given a global rank, return all global ranks involved in its associated process group in each axis.

        Example:

        ```python
        physical_mesh_id = torch.arange(0, 16)
        mesh_shape = (4, 4)

        # logical mesh will look like
        # [[0, 1, 2, 3],
        #  [4, 5, 6, 7],
        #  [8, 9, 10,11],
        #  [12,13,14,15]]

        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
        print(device_mesh._collate_global_ranks_in_same_process_group(0))

        # key is the axis
        # value is a list of global ranks in the same process group as rank 0 along that axis
        # output will look like
        # {
        #    0: [0, 4, 8, 12],
        #    1: [0, 1, 2, 3]
        # }
        ```
        """
        # The global-rank-to-local-rank mapping has already been built by
        # _init_global_to_logical_rank_mapping and stored in self._global_to_local_rank_mapping.
        # The key is the global rank and the value is the list of local ranks
        # corresponding to the global rank with respect to the different axes.
        # We can see the list of local ranks as the process coordinates for simplicity.
        # The keys and values are all unique, therefore
        # we can also use the coordinates to find the global rank.

        # =========================================================================
        # Step 1
        # find all the process_coordinates for processes in the same process group
        # as the given global rank
        # =========================================================================

        # the key of this dict is the axis
        # the value is a list of process coordinates that fall in the same process group along that axis
        processes_in_the_same_process_group = {}

        for dim in range(self.logical_mesh_id.dim()):
            # iterate over the dimension size so that we can include all processes
            # in the same process group in the given axis
            # the _local_rank refers to the local rank of the current process
            for _local_rank in range(self.logical_mesh_id.shape[dim]):
                # if this dimension is not initialized yet,
                # initialize it with an empty array
                if dim not in processes_in_the_same_process_group:
                    processes_in_the_same_process_group[dim] = []

                # get the local rank corresponding to the global rank
                process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()

                # replace the local rank in the given dimension with the
                # local rank of the current process iterated
                process_coordinates[dim] = _local_rank
                processes_in_the_same_process_group[dim].append(process_coordinates)

        # =================================================================
        # Step 2
        # Use local rank combination to find its corresponding global rank
        # =================================================================
        # the key of the dict is the axis
        # the value is the list of global ranks which are in the same process group as the given global rank
        global_pg_ranks = {}
        for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
            global_pg_ranks[dim] = []
            for process_coordinates in coordinates_of_all_processes:
                # find the global rank by local rank combination
                for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
                    if process_coordinates == _process_coordinates:
                        global_pg_ranks[dim].append(_global_rank)
        return global_pg_ranks

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1D logical mesh.
        """
        if self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported because no global rank information is known."
            )

        flatten_mesh_shape_size = len(self._mesh_shape)
        flatten_mesh_shape = [self.num_devices]
        return DeviceMesh(
            self._physical_mesh_id,
            tuple(flatten_mesh_shape),
            mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
            mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
            init_process_group=self._init_process_group,
        )
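
    # The cost methods below implement a simple alpha-beta model: an alpha (latency) term plus
    # a beta (inverse bandwidth) term proportional to the bytes moved, plus a small constant.
    # Worked example (illustrative): with the default mesh_alpha = mesh_beta = (1, 1) and 4 devices
    # on axis 0, all_reduce_cost(1024, 0) = 1 + 2 * 3 / 4 * 1024 + 0.01 = 1537.01.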

    def all_gather_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
502
        return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1

    def all_reduce_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (
            self.mesh_alpha[mesh_dim]
            + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
            + 0.01
        )

    def reduce_scatter_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (
            self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
        )

    def all_to_all_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        penalty_factor = num_devices / 2.0
        return (
            self.mesh_alpha[mesh_dim]
            + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
            + 0.001
        )
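

# A minimal, single-process usage sketch with illustrative values: it only builds the logical
# view and queries it, so it runs without torch.distributed being initialized.
if __name__ == "__main__":
    physical_mesh_id = torch.arange(16)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(4, 4))

    print(device_mesh.shape)  # (4, 4)
    print(device_mesh.num_devices)  # 16
    print(device_mesh.global_rank_to_local_rank(6))  # [1, 2]
    print(device_mesh.get_ranks_in_process_group(axis=1, global_rank=6))  # [4, 5, 6, 7]
    print(device_mesh.flatten().shape)  # (16,)
    print(device_mesh.all_reduce_cost(1024, mesh_dim=0))  # 1537.01 with the default alpha/beta of 1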