from typing import List


def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    """Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size

        parallel_size (List[int]): The parallel size of each orthogonal parallel type.
            For example, if tensor_parallel_size = 2, pipeline_model_parallel_size = 3,
            data_parallel_size = 4, and the parallel mapping order is tp-pp-dp, then
            parallel_size = [2, 3, 4].

        mask (List[bool]): The mask controls which parallel methods the generated groups
            represent. If mask[i] is True, the generated group contains the i-th
            parallelism method. For example, if parallel_size = [tp_size, pp_size, dp_size]
            and mask = [True, False, True], the generated group is the `tp-dp` group; if
            mask = [False, True, False], the generated group is the `pp` group.

    Algorithm:
        For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and the local
        ranks satisfy the following equation:
            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size  (1)
                tp_rank in [0, tp_size)
                dp_rank in [0, dp_size)
                pp_rank in [0, pp_size)

        Suppose we want the `dp_group` (tp_size * pp_size groups of dp_size ranks each;
        for example, if the GPU count is 8, the order is 'tp-pp-dp' and the sizes are
        '2-2-2', the dp_group is [[0, 4], [1, 5], [2, 6], [3, 7]]). The tp_rank and
        pp_rank are combined to form the `dp_group_index`:
            dp_group_index = tp_rank + pp_rank * tp_size  (2)

        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank ranges over
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy equation (1).

        This function solves this math problem.

    For example, if parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4]
    and mask = [False, True, False], then
        dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
        dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
        ...
        dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

        dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
        dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
        ...
        dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
    """

    def prefix_product(a: List[int], init=1) -> List[int]:
        # Prefix products: [init, init*a[0], init*a[0]*a[1], ...]
        r = [init]
        for v in a:
            init = init * v
            r.append(init)
        return r

    def inner_product(a: List[int], b: List[int]) -> int:
        return sum([x * y for x, y in zip(a, b)])

    def decompose(index, shape, stride=None):
        """
        This function solves the math problem below:
            There is an equation:
                index = sum(idx[i] * stride[i])
            Given the values of index and stride, return idx.
        It is used to recover the tp/dp/pp rank from group_index and rank_in_group.
        """
        if stride is None:
            stride = prefix_product(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride is a prefix_product result, and the value of stride[-1]
        # is not used.
        assert (
            sum([x * y for x, y in zip(idx, stride[:-1])]) == index
        ), "index {} with shape {} mismatches the decomposed idx {}".format(index, shape, idx)
        return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    ranks = []
    for group_index in range(num_of_group):
        # get indices from the unmasked dims for group_index.
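        # Note (illustrative): decomposing group_index over the unmasked dims recovers
        # the fixed coordinates shared by every rank in this group. E.g. for a dp group
        # with parallel_size = [tp, dp, pp], the unmasked shape is [tp_size, pp_size] and
        # group_index = tp_rank + pp_rank * tp_size decomposes back into
        # (tp_rank, pp_rank), matching equation (2) in the docstring.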
        decomposed_group_idx = decompose(group_index, unmasked_shape)
        rank = []
        for rank_in_group in range(group_size):
            # get indices from the masked dims for rank_in_group.
            decomposed_rank_idx = decompose(rank_in_group, masked_shape)
            rank.append(
                inner_product(decomposed_rank_idx, masked_stride)
                + inner_product(decomposed_group_idx, unmasked_stride)
            )
        ranks.append(rank)
    return ranks


class RankGenerator(object):
    def __init__(
        self,
        tp: int,
        sp: int,
        pp: int,
        cfg: int,
        dp: int,
        order: str,
        rank_offset: int = 0,
    ) -> None:
        self.tp = tp
        self.sp = sp
        self.pp = pp
        self.cfg = cfg
        self.dp = dp
        self.rank_offset = rank_offset
        self.world_size = tp * sp * pp * cfg * dp

        self.name_to_size = {
            "tp": self.tp,
            "sp": self.sp,
            "pp": self.pp,
            "cfg": self.cfg,
            "dp": self.dp,
        }
        order = order.lower()

        # Any parallelism not mentioned in `order` must have size 1; it is appended
        # so that `ordered_size` always covers all five dimensions.
        for name in self.name_to_size.keys():
            if name not in order and self.name_to_size[name] != 1:
                raise RuntimeError(
                    f"The size of ({name}) is ({self.name_to_size[name]}), but it is not included in the order ({order})."
                )
            elif name not in order:
                order = order + "-" + name

        self.order = order
        self.ordered_size = []

        for token in order.split("-"):
            self.ordered_size.append(self.name_to_size[token])

    def get_mask(self, order: str, token: str):
        ordered_token = order.split("-")
        token = token.split("-")
        mask = [False] * len(ordered_token)
        for t in token:
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token):
        """Get rank groups by input token.

        Arguments:
            token (str): Specify the type of ranks to get. To obtain multiple
                parallel types, separate them with a hyphen '-'. For example,
                to obtain the TP-DP group, the token should be 'tp-dp'.
        """
        mask = self.get_mask(self.order, token)
        ranks = generate_masked_orthogonal_rank_groups(
            self.world_size, self.ordered_size, mask
        )
        if self.rank_offset > 0:
            for rank_group in ranks:
                for i in range(len(rank_group)):
                    rank_group[i] += self.rank_offset
        return ranks
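

# The usage sketch below is illustrative and not part of the original module. It assumes
# a hypothetical configuration (tp=2, sp=1, pp=2, cfg=1, dp=2 on 8 ranks) purely to show
# how RankGenerator and the mask-based group generation compose.
if __name__ == "__main__":
    # Order 'tp-sp-pp-cfg-dp': tp is the fastest-varying dimension, dp the slowest.
    gen = RankGenerator(tp=2, sp=1, pp=2, cfg=1, dp=2, order="tp-sp-pp-cfg-dp")

    # 'dp' groups: one group per (tp, sp, pp, cfg) combination, dp ranks in each.
    print(gen.get_ranks("dp"))     # [[0, 4], [1, 5], [2, 6], [3, 7]]

    # 'tp-pp' groups: tp * pp ranks each, one group per dp index.
    print(gen.get_ranks("tp-pp"))  # [[0, 1, 2, 3], [4, 5, 6, 7]]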