import copy
import logging
import os

import numpy as np

import pyarrow
import torch
from gloo_wrapper import alltoallv_cpu

from pyarrow import csv
from utils import map_partid_rank


class DistLookupService:
    """
    This is an implementation of a Distributed Lookup Service which provides the following
    services to its users: map 1) global node-ids to partition-ids, and 2) global node-ids
    to shuffle global node-ids (contiguous within each node type and across
    all the partitions).

    This service initializes itself with the node-id to partition-id mappings, which are inputs
    to this service. The node-id to partition-id mappings are assumed to be in one file for each
    node type. These node-id-to-partition-id mappings are split across the service processes so
    that each process ends up with a contiguous chunk. The service first divides the number of
    mappings (node-id to partition-id) for each node type into equal chunks across all the
    service processes, so each service process will be the owner of a set of
    node-id-to-partition-id mappings. This class has two functions, which are as follows:

    1) `get_partition_ids`, which returns the node-id to partition-id mappings to the user
    2) `get_shuffle_nids`, which returns the node-id to shuffle-node-id mappings to the user

    Parameters:
    -----------
    input_dir : string
        the input directory where the per-node-type partition-id files are located
    ntype_names : list of strings
        names of the node types; for each name a file within the input_dir directory is
        read, whose contents are the partition-ids of the node-ids of that node type
    id_map : dgl.distributed.id_map instance
        this id_map is used to retrieve ntype-ids (node type ids) and type_nids (per-type
        node ids) for any given global node id
    rank : integer
        integer indicating the rank of a given process
    world_size : integer
        integer indicating the total no. of processes
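
    Examples:
    ---------
    A minimal sketch of how this service might be driven from each rank; the
    `id_map` object, the on-disk partition-id files and the node-id arrays are
    assumed to exist already, so this is illustrative rather than a runnable recipe:

    >>> lookup = DistLookupService(
    ...     "/path/to/partition_dir", ["user", "item"], id_map, rank, world_size
    ... )
    >>> part_ids = lookup.get_partition_ids(global_nids)
    >>> shuffle_nids = lookup.get_shuffle_nids(
    ...     global_nids, my_global_nids, my_shuffle_global_nids, world_size
    ... )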
    """

    def __init__(self, input_dir, ntype_names, id_map, rank, world_size):
        assert os.path.isdir(input_dir)
        assert ntype_names is not None
        assert len(ntype_names) > 0

        # These lists are indexed by ntype_ids.
        type_nid_begin = []
        type_nid_end = []
        partid_list = []
        ntype_count = []

        # Iterate over the node types and extract the partition id mappings.
        for ntype in ntype_names:
            filename = f"{ntype}.txt"
            logging.info(
                f"[Rank: {rank}] Reading file: {os.path.join(input_dir, filename)}"
            )

            read_options = pyarrow.csv.ReadOptions(
                use_threads=True,
                block_size=4096,
                autogenerate_column_names=True,
            )
            parse_options = pyarrow.csv.ParseOptions(delimiter=" ")
            ntype_partids = []
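            # Note: open_csv streams the file in blocks (block_size bytes at a time),
            # so an entire partition-id file never has to be loaded into memory at once.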
            with pyarrow.csv.open_csv(
                os.path.join(input_dir, "{}.txt".format(ntype)),
                read_options=read_options,
                parse_options=parse_options,
            ) as reader:
                for next_chunk in reader:
                    if next_chunk is None:
                        break
                    next_table = pyarrow.Table.from_batches([next_chunk])
                    ntype_partids.append(next_table["f0"].to_numpy())

            ntype_partids = np.concatenate(ntype_partids)
            count = len(ntype_partids)
            ntype_count.append(count)

            # Each rank owns a contiguous set of partition-ids, which are equally split
            # across all the processes.
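            # For example (illustrative numbers only): with count = 10 and world_size = 4,
            # split_size = ceil(10 / 4) = 3, so rank 0 owns [0, 3), rank 1 owns [3, 6),
            # rank 2 owns [6, 9) and the last rank owns the remainder [9, 10).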
            split_size = np.ceil(count / np.int64(world_size)).astype(np.int64)
            start, end = (
                np.int64(rank) * split_size,
                np.int64(rank + 1) * split_size,
            )
            if rank == (world_size - 1):
                end = count
            type_nid_begin.append(start)
            type_nid_end.append(end)

            # Slice the partition-ids which belong to the current instance.
            partid_list.append(copy.deepcopy(ntype_partids[start:end]))

            # Explicitly release the array read from the file.
            del ntype_partids

        # Store all the information in the object instance variable.
        self.id_map = id_map
        self.type_nid_begin = np.array(type_nid_begin, dtype=np.int64)
        self.type_nid_end = np.array(type_nid_end, dtype=np.int64)
        self.partid_list = partid_list
        self.ntype_count = np.array(ntype_count, dtype=np.int64)
        self.rank = rank
        self.world_size = world_size

    def get_partition_ids(self, global_nids):
        """
        This function is used to get the partition-ids for a given set of global node ids.

        The global_nid <-> partition-id mappings are deterministically distributed across
        all the participating processes within the service. A contiguous chunk of global-nids
        (ntype-ids, per-type-nids) is stored within each process, and this chunk is determined
        by the total no. of nodes of a given ntype-id and the rank of the process.

        The process where a given global_nid <-> partition-id mapping is stored can be easily
        computed as described above. Once this is determined, we perform an alltoallv to send
        the requests. On the receiving side, each process receives a set of global_nids and
        retrieves the corresponding partition-ids using locally stored lookup tables. It builds
        responses for all the other processes and performs an alltoallv.

        Once the responses, partition-ids, are received, they are re-ordered to match the
        incoming global-nids order and returned to the caller.

        Parameters:
        -----------
        self : instance of this class
            instance of this class, which is passed by the runtime implicitly
        global_nids : numpy array
            an array of global node-ids for which partition-ids are to be retrieved by
            the distributed lookup service

        Returns:
        --------
        numpy array :
            an array of partition-ids corresponding to the global-node-ids passed as the
            function argument
        """

        # Find the process where global_nid --> partition-id(owner) is stored.
        ntype_ids, type_nids = self.id_map(global_nids)
        ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy()
        assert len(ntype_ids) == len(global_nids)

        # For each node-type, the per-type-node-id <-> partition-id mappings are
        # stored as contiguous chunks by this lookup service. These mappings are
        # equally split among all the processes in the lookup service, deterministically.
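        # For example (illustrative numbers only): if a node type has 10 nodes and
        # world_size is 4, chunk_sizes is ceil(10 / 4) = 3, so per-type node-id 7 is
        # owned by rank 7 // 3 = 2.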
        typeid_counts = self.ntype_count[ntype_ids]
        chunk_sizes = np.ceil(typeid_counts / self.world_size).astype(np.int64)
        service_owners = np.floor_divide(type_nids, chunk_sizes).astype(
            np.int64
        )

        # Now `service_owners` is a list of ranks (process-ids) which own the corresponding
        # global-nid <-> partition-id mapping.

        # Split the input global_nids into a list of lists, where each list will be
        # sent to the respective rank/process.
        # We also need to store the indices, in the indices_list, so that we can re-order
        # the final result (partition-ids) in the same order as the global-nids (function argument).
        send_list = []
        indices_list = []
        for idx in range(self.world_size):
            idxes = np.where(service_owners == idx)
            ll = global_nids[idxes[0]]
            send_list.append(torch.from_numpy(ll))
            indices_list.append(idxes[0])
        assert len(np.concatenate(indices_list)) == len(global_nids)
        assert np.all(
            np.sort(np.concatenate(indices_list)) == np.arange(len(global_nids))
        )

        # Send the request to everyone else.
        # As a result of this operation, the current process also receives a list of lists
        # from all the other processes.
        # These lists are global-node-ids whose global-node-id <-> partition-id mappings
        # are owned/stored by the current process.
        owner_req_list = alltoallv_cpu(self.rank, self.world_size, send_list)

        # Create a response list for each of the request lists received in the previous
        # step. Populate the respective partition-ids in these response lists appropriately.
        out_list = []
        for idx in range(self.world_size):
            if owner_req_list[idx] is None:
                out_list.append(torch.empty((0,), dtype=torch.int64))
                continue
            # Get the node_type_ids and per_type_nids for the incoming global_nids.
            ntype_ids, type_nids = self.id_map(owner_req_list[idx].numpy())
            ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy()

            # Lists to store partition-ids for the incoming global-nids.
            type_id_lookups = []
            local_order_idx = []

            # Now iterate over all the node_types and accumulate all the partition-ids.
            # Since the accumulated partition-ids follow node_type order, they
            # must be re-ordered as per the order of the input, which may be different.
            for tid in range(len(self.partid_list)):
                cond = ntype_ids == tid
                local_order_idx.append(np.where(cond)[0])
                global_type_nids = type_nids[cond]
                if len(global_type_nids) <= 0:
                    continue

                local_type_nids = global_type_nids - self.type_nid_begin[tid]

                assert np.all(local_type_nids >= 0)
                assert np.all(
                    local_type_nids
                    <= (self.type_nid_end[tid] + 1 - self.type_nid_begin[tid])
                )

                cur_owners = self.partid_list[tid][local_type_nids]
                type_id_lookups.append(cur_owners)

            # Reorder the partition-ids so that they agree with the input order --
            # which is the order in which the incoming message is received.
            if len(type_id_lookups) <= 0:
                out_list.append(torch.empty((0,), dtype=torch.int64))
            else:
                # Now reorder results for each request.
                sort_order_idx = np.argsort(np.concatenate(local_order_idx))
                lookups = np.concatenate(type_id_lookups)[sort_order_idx]
                out_list.append(torch.from_numpy(lookups))

        # Send the partition-ids to their respective requesting processes.
        owner_resp_list = alltoallv_cpu(self.rank, self.world_size, out_list)

        # owner_resp_list is a list of tensors (or None) where each tensor contains
        # the partition-ids which the current process requested.
        # Now we need to re-order so that the partition-ids correspond to the
        # global_nids which are passed into this function.

        # Order according to the requesting order.
        # owner_resp_list holds the owner-ids (partition-ids) for global_nids (the function argument).
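        # For example (illustrative values only): if indices_list is [[2], [0, 3], [1]],
        # then global_nids_order is [2, 0, 3, 1] and np.argsort gives [1, 3, 0, 2];
        # indexing the concatenated responses with that permutation puts the
        # partition-ids back into the original global_nids order.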
        owner_ids = torch.cat(
            [x for x in owner_resp_list if x is not None]
        ).numpy()
        assert len(owner_ids) == len(global_nids)

        global_nids_order = np.concatenate(indices_list)
        sort_order_idx = np.argsort(global_nids_order)
        owner_ids = owner_ids[sort_order_idx]
        global_nids_order = global_nids_order[sort_order_idx]
        assert np.all(np.arange(len(global_nids)) == global_nids_order)

        # Now the owner_ids (partition-ids) correspond to the global_nids.
        return owner_ids

    def get_shuffle_nids(
        self, global_nids, my_global_nids, my_shuffle_global_nids, world_size
    ):
        """
        This function is used to retrieve shuffle_global_nids for a given set of incoming
        global_nids. Note that global_nids are in random order and will contain duplicates.

        This function first retrieves the partition-ids of the incoming global_nids.
        These partition-ids are also the ranks of the processes which own the respective
        global-nids as well as shuffle-global-nids. An alltoallv is performed to send the
        global-nids to the respective ranks/partition-ids where the mapping
        global-nid <-> shuffle-global-nid is located.

        On the receiving side, once the global-nids are received, the associated
        shuffle-global-nids are retrieved and an alltoallv is performed to send the
        responses back to all the other processes.

        Once the responses, shuffle-global-nids, are received, they are re-ordered according
        to the incoming global-nids order and returned to the caller.

        Parameters:
        -----------
        self : instance of this class
            instance of this class, which is passed by the runtime implicitly
        global_nids : numpy array
            an array of global node-ids for which shuffle-global-nids are to be retrieved
            by the distributed lookup service
        my_global_nids : numpy ndarray
            array of global_nids which are owned by the current partition/rank/process;
            this process has the node <-> partition-id mapping for these nodes
        my_shuffle_global_nids : numpy ndarray
            array of shuffle_global_nids which are assigned by the current process/rank
        world_size : int
            total no. of processes in the MPI_WORLD

        Returns:
        --------
        numpy array :
            array of shuffle_global_nids which correspond to the incoming node-ids in
            global_nids (the function argument)
        """

        # Get the owner_ids (partition-ids or rank).
        owner_ids = self.get_partition_ids(global_nids)

        # These owner_ids, which are also the partition-ids of the nodes in the
        # input graph, are in the range 0 - (num_partitions - 1).
        # These ids are generated using some kind of graph partitioning method.
        # The distributed lookup service, as used by the graph partitioning
        # pipeline, is used to store ntype-ids (also type_nids) and their
        # mapping to the associated partition-id.
        # These ids are split into `num_process` chunks and the processes in the
        # dist. lookup service are assigned the ownership of these chunks.
        # The pipeline also enforces the following constraint among the
        # pipeline input parameters, num_partitions and num_processes:
        #   num_partitions is an integer multiple of num_processes,
        #   which means each individual node in the cluster will be running
        #   an equal number of processes.
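        # For example (assuming map_partid_rank spreads partitions evenly over the
        # service ranks): with num_partitions = 8 and world_size = 4, each rank would
        # serve the mappings for two of the eight partitions.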
        owner_ids = map_partid_rank(owner_ids, world_size)

        # Ask these owners to supply the shuffle_global_nids.
        send_list = []
        id_list = []
        for idx in range(self.world_size):
            cond = owner_ids == idx
            idxes = np.where(cond)
            ll = global_nids[idxes[0]]
            send_list.append(torch.from_numpy(ll))
            id_list.append(idxes[0])

        assert len(np.concatenate(id_list)) == len(global_nids)
        cur_global_nids = alltoallv_cpu(self.rank, self.world_size, send_list)

        # At this point, the current process has received a list of lists, each containing
        # global-nids whose corresponding shuffle_global_nids are located
        # in the current process.
        shuffle_nids_list = []
        for idx in range(self.world_size):
            if cur_global_nids[idx] is None:
                shuffle_nids_list.append(torch.empty((0,), dtype=torch.int64))
                continue

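            # For example (illustrative values only): if this rank receives [7, 3, 7],
            # np.unique gives uniq_ids = [3, 7] and inverse_idx = [1, 0, 1];
            # np.intersect1d then locates [3, 7] within my_global_nids (idx2), and
            # my_shuffle_global_nids[idx2][inverse_idx] restores the duplicated
            # request order.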
            uniq_ids, inverse_idx = np.unique(
                cur_global_nids[idx], return_inverse=True
            )
            common, idx1, idx2 = np.intersect1d(
                uniq_ids,
                my_global_nids,
                assume_unique=True,
                return_indices=True,
            )
            assert len(common) == len(uniq_ids)

            req_shuffle_global_nids = my_shuffle_global_nids[idx2][inverse_idx]
            assert len(req_shuffle_global_nids) == len(cur_global_nids[idx])
            shuffle_nids_list.append(torch.from_numpy(req_shuffle_global_nids))

        # Send the shuffle-global-nids to their respective ranks.
        mapped_global_nids = alltoallv_cpu(
            self.rank, self.world_size, shuffle_nids_list
        )
        for idx in range(len(mapped_global_nids)):
            if mapped_global_nids[idx] is None:
                mapped_global_nids[idx] = torch.empty((0,), dtype=torch.int64)

        # Reorder to match global_nids (function parameter).
        global_nids_order = np.concatenate(id_list)
        shuffle_global_nids = torch.cat(mapped_global_nids).numpy()
        assert len(shuffle_global_nids) == len(global_nids)

        sorted_idx = np.argsort(global_nids_order)
        shuffle_global_nids = shuffle_global_nids[sorted_idx]
        global_nids_ordered = global_nids_order[sorted_idx]
        assert np.all(global_nids_ordered == np.arange(len(global_nids)))

        return shuffle_global_nids