import gc
import logging
import math
import os
import sys
from datetime import timedelta
from timeit import default_timer as timer

import constants

import dgl
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from convert_partition import create_dgl_object, create_metadata_json
from dataset_utils import get_dataset
from dist_lookup import DistLookupService
from globalids import (
    assign_shuffle_global_nids_edges,
    assign_shuffle_global_nids_nodes,
    lookup_shuffle_global_nids_edges,
)
from gloo_wrapper import allgather_sizes, alltoallv_cpu, gather_metadata_json
from utils import (
    augment_edge_data,
    get_edge_types,
    get_etype_featnames,
    get_gnid_range_map,
    get_idranges,
    get_node_types,
    get_ntype_featnames,
    map_partid_rank,
    memory_snapshot,
    read_json,
    read_ntype_partition_files,
    write_dgl_objects,
    write_metadata_json,
)


def gen_node_data(
    rank, world_size, num_parts, id_lookup, ntid_ntype_map, schema_map
):
    """
    For this data processing pipeline, reading node files is not needed. All the needed information about
    the nodes can be found in the metadata json file. This function generates the nodes owned by a given
    process, using metis partitions.

    Parameters:
    -----------
    rank : int
        rank of the process
    world_size : int
        total no. of processes
    num_parts : int
        total no. of partitions
    id_lookup : instance of class DistLookupService
       Distributed lookup service used to map global-nids to respective partition-ids and
       shuffle-global-nids
    ntid_ntype_map : dictionary
        a dictionary where keys are node_type ids (integers) and values are node_type names (strings).
    schema_map : dictionary
        dictionary formed by reading the input metadata json file for the input dataset.

        Please note that it is assumed that, for the input graph files, the nodes of a particular node-type are
        split into `p` files (because `p` partitions are to be generated). On a similar note, edges of a particular
        edge-type are split into `p` files as well.

        # assuming m node types are present in the input graph
        "num_nodes_per_chunk" : [
            [a0, a1, a2, ... a<p-1>],
            [b0, b1, b2, ... b<p-1>],
            ...
            [m0, m1, m2, ... m<p-1>]
        ]
        Here, each sub-list, corresponding to a node type in the input graph, has `p` elements, for instance [a0, a1, ... a<p-1>],
        where each element represents the number of nodes which are to be processed by a process during distributed partitioning.

        In addition to the above key-value pair for the nodes in the graph, the node features are captured in the
        "node_data" key-value pair. In this dictionary the keys are node-type names and each value is a dictionary which
        is used to capture all the features present for that particular node-type. This is shown in the following example:

        "node_data" : {
            "paper": {       # node type
                "feat": {   # feature key
                    "format": {"name": "numpy"},
                    "data": ["node_data/paper-feat-part1.npy", "node_data/paper-feat-part2.npy"]
                },
                "label": {   # feature key
                    "format": {"name": "numpy"},
                    "data": ["node_data/paper-label-part1.npy", "node_data/paper-label-part2.npy"]
                },
                "year": {   # feature key
                    "format": {"name": "numpy"},
                    "data": ["node_data/paper-year-part1.npy", "node_data/paper-year-part2.npy"]
                }
            }
        }

        In the above example we have one node-type, "paper", which has 3 features, namely feat, label and year.
        Each feature has `p` files whose locations in the filesystem are listed under the "data" key, and "format"
        is used to describe the storage format.

    Returns:
    --------
    dictionary :
        dictionary where keys are column names and values are numpy arrays; these arrays are generated by
        using information present in the metadata json file

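    Example of the returned dictionary (illustrative; array contents depend on
    the input dataset and the metis partition assignments):

        {
            constants.GLOBAL_NID + "/0": <np.ndarray of global node-ids owned by this rank>,
            constants.NTYPE_ID + "/0": <np.ndarray of node-type ids, one per owned node>,
            constants.GLOBAL_TYPE_NID + "/0": <np.ndarray of per-type node-ids, one per owned node>,
        }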
    """
    local_node_data = {}
    for local_part_id in range(num_parts // world_size):
        local_node_data[constants.GLOBAL_NID + "/" + str(local_part_id)] = []
        local_node_data[constants.NTYPE_ID + "/" + str(local_part_id)] = []
        local_node_data[
            constants.GLOBAL_TYPE_NID + "/" + str(local_part_id)
        ] = []

    # Note that `get_idranges` always returns two dictionaries. Keys in these
    # dictionaries are type names for nodes and edges and values are
    # `num_parts` number of tuples indicating the range of type-ids in first
    # dictionary and range of global-nids in the second dictionary.
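    # For example (illustrative), with a single node type "paper" that has
    # 1000 nodes split across num_parts = 2 chunks, the two dictionaries could
    # look roughly like:
    #   type_nid_dict   = {"paper": [(0, 500), (500, 1000)]}
    #   global_nid_dict = {"paper": np.array([[0, 1000]])}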
    type_nid_dict, global_nid_dict = get_idranges(
        schema_map[constants.STR_NODE_TYPE],
        schema_map[constants.STR_NUM_NODES_PER_CHUNK],
        num_chunks=num_parts,
    )

    for ntype_id, ntype_name in ntid_ntype_map.items():
        type_start, type_end = (
            type_nid_dict[ntype_name][0][0],
            type_nid_dict[ntype_name][-1][1],
        )
        gnid_start, gnid_end = (
            global_nid_dict[ntype_name][0, 0],
            global_nid_dict[ntype_name][0, 1],
        )

        node_partid_slice = id_lookup.get_partition_ids(
            np.arange(gnid_start, gnid_end, dtype=np.int64)
        )  # exclusive

        for local_part_id in range(num_parts // world_size):
            cond = node_partid_slice == (rank + local_part_id * world_size)
            own_gnids = np.arange(gnid_start, gnid_end, dtype=np.int64)
            own_gnids = own_gnids[cond]

            own_tnids = np.arange(type_start, type_end, dtype=np.int64)
            own_tnids = own_tnids[cond]

            local_node_data[
                constants.NTYPE_ID + "/" + str(local_part_id)
            ].append(np.ones(own_gnids.shape, dtype=np.int64) * ntype_id)
            local_node_data[
                constants.GLOBAL_NID + "/" + str(local_part_id)
            ].append(own_gnids)
            local_node_data[
                constants.GLOBAL_TYPE_NID + "/" + str(local_part_id)
            ].append(own_tnids)

    for k in local_node_data.keys():
        local_node_data[k] = np.concatenate(local_node_data[k])

    return local_node_data


def exchange_edge_data(rank, world_size, num_parts, edge_data):
    """
    Exchange edge_data among processes in the world.
    Prepare lists of sliced data targeting each process and invoke
    alltoallv_cpu to perform the message exchange.

    Parameters:
    -----------
    rank : int
        rank of the process
    world_size : int
        total no. of processes
    num_parts : int
        total no. of partitions
    edge_data : dictionary
        edge information, as a dictionary which stores column names as keys and
        column data as values. This information is read from the edges.txt file.

    Returns:
    --------
    dictionary :
        the input argument, edge_data, is updated with the edge data received by other processes
        in the world.
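
    Example (illustrative): each per-rank message prepared by this function is
    a 2-D int64 tensor with 5 columns, in this order:

        [GLOBAL_SRC_ID, GLOBAL_DST_ID, GLOBAL_TYPE_EID, ETYPE_ID, GLOBAL_EID]

    The received rows are stored back into `edge_data` under keys of the form
    `<column-name>/<local_part_id>`.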
    """

    # Prepare data for each rank in the cluster.
    start = timer()
    for local_part_id in range(num_parts // world_size):
        input_list = []
        for idx in range(world_size):
            send_idx = edge_data[constants.OWNER_PROCESS] == (
                idx + local_part_id * world_size
            )
            send_idx = send_idx.reshape(
                edge_data[constants.GLOBAL_SRC_ID].shape[0]
            )
            filt_data = np.column_stack(
                (
                    edge_data[constants.GLOBAL_SRC_ID][send_idx == 1],
                    edge_data[constants.GLOBAL_DST_ID][send_idx == 1],
                    edge_data[constants.GLOBAL_TYPE_EID][send_idx == 1],
                    edge_data[constants.ETYPE_ID][send_idx == 1],
                    edge_data[constants.GLOBAL_EID][send_idx == 1],
                )
            )
            if filt_data.shape[0] <= 0:
                input_list.append(torch.empty((0, 5), dtype=torch.int64))
            else:
                input_list.append(torch.from_numpy(filt_data))

        dist.barrier()
        output_list = alltoallv_cpu(
            rank, world_size, input_list, retain_nones=False
        )

        # Replace the values of the edge_data, with the received data from all the other processes.
        rcvd_edge_data = torch.cat(output_list).numpy()
        edge_data[
            constants.GLOBAL_SRC_ID + "/" + str(local_part_id)
        ] = rcvd_edge_data[:, 0]
        edge_data[
            constants.GLOBAL_DST_ID + "/" + str(local_part_id)
        ] = rcvd_edge_data[:, 1]
        edge_data[
            constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)
        ] = rcvd_edge_data[:, 2]
        edge_data[
            constants.ETYPE_ID + "/" + str(local_part_id)
        ] = rcvd_edge_data[:, 3]
        edge_data[
            constants.GLOBAL_EID + "/" + str(local_part_id)
        ] = rcvd_edge_data[:, 4]

    end = timer()
    logging.info(
        f"[Rank: {rank}] Time to send/rcv edge data: {timedelta(seconds=end-start)}"
    )

    # Clean up.
    edge_data.pop(constants.OWNER_PROCESS)
    edge_data.pop(constants.GLOBAL_SRC_ID)
    edge_data.pop(constants.GLOBAL_DST_ID)
    edge_data.pop(constants.GLOBAL_TYPE_EID)
    edge_data.pop(constants.ETYPE_ID)
    edge_data.pop(constants.GLOBAL_EID)

    return edge_data


def exchange_feature(
    rank,
    data,
    id_lookup,
    feat_type,
    feat_key,
    featdata_key,
    gid_start,
    gid_end,
    type_id_start,
    type_id_end,
    local_part_id,
    world_size,
    num_parts,
    cur_features,
    cur_global_ids,
):
    """This function is used to send/receive one feature for either nodes or
273
274
275
276
277
    edges of the input graph dataset.

    Parameters:
    -----------
    rank : int
278
        integer, unique id assigned to the current process
279
    data: dicitonary
280
281
        dictionry in which node or edge features are stored and this information
        is read from the appropriate node features file which belongs to the
282
283
284
285
286
        current process
    id_lookup : instance of DistLookupService
        instance of an implementation of dist. lookup service to retrieve values
        for keys
    feat_type : string
287
288
289
        this is used to distinguish which features are being exchanged. Please
        note that for nodes ownership is clearly defined and for edges it is
        always assumed that destination end point of the edge defines the
290
291
        ownership of that particular edge
    feat_key : string
292
        this string is used as a key in the dictionary to store features, as
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
        tensors, in local dictionaries
    featdata_key : numpy array
        features associated with this feature key being processed
    gid_start : int
        starting global_id, of either node or edge, for the feature data
    gid_end : int
        ending global_if, of either node or edge, for the feature data
    type_id_start : int
        starting type_id for the feature data
    type_id_end : int
        ending type_id for the feature data
    local_part_id : int
        integers used to the identify the local partition id used to locate
        data belonging to this partition
    world_size : int
        total number of processes created
    num_parts : int
        total number of partitions
    cur_features : dictionary
312
        dictionary to store the feature data which belongs to the current
313
314
        process
    cur_global_ids : dictionary
315
        dictionary to store global ids, of either nodes or edges, for which
316
        the features stored in the cur_features dictionary
317

318
319
320
    Returns:
    -------
    dictionary :
321
        a dictionary is returned where keys are type names and
322
323
        feature data are the values
    list :
324
        a dictionary of global_ids either nodes or edges whose features are
325
326
        received during the data shuffle process
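
    Example (illustrative): with feat_key = "paper/feat/0" and
    local_part_id = 1, the received feature rows are accumulated under the key
    "paper/feat/1" in cur_features, and the matching global ids under the same
    key in cur_global_ids.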
    """
    # type_ids for this feature subset on the current rank
    gids_feat = np.arange(gid_start, gid_end)
    tids_feat = np.arange(type_id_start, type_id_end)
    local_idx = np.arange(0, type_id_end - type_id_start)

    feats_per_rank = []
    global_id_per_rank = []

    tokens = feat_key.split("/")
    assert len(tokens) == 3
    local_feat_key = "/".join(tokens[:-1]) + "/" + str(local_part_id)
    for idx in range(world_size):
        # Get the partition ids for the range of global nids.
        if feat_type == constants.STR_NODE_FEATURES:
            # Retrieve the partition ids for the node features.
            # Each partition id will be in the range [0, num_parts).
            partid_slice = id_lookup.get_partition_ids(
                np.arange(gid_start, gid_end, dtype=np.int64)
            )
        else:
            # Edge data case.
            # Ownership is determined by the destination node.
            assert data is not None
            global_eids = np.arange(gid_start, gid_end, dtype=np.int64)

            # Now use `data` to extract destination nodes' global id
            # and use that to get the ownership
            common, idx1, idx2 = np.intersect1d(
                data[constants.GLOBAL_EID], global_eids, return_indices=True
            )
            assert common.shape[0] == idx2.shape[0]

            global_dst_nids = data[constants.GLOBAL_DST_ID][idx1]
            assert np.all(global_eids == data[constants.GLOBAL_EID][idx1])
            partid_slice = id_lookup.get_partition_ids(global_dst_nids)

        cond = partid_slice == (idx + local_part_id * world_size)
        gids_per_partid = gids_feat[cond]
        tids_per_partid = tids_feat[cond]
        local_idx_partid = local_idx[cond]

        if gids_per_partid.shape[0] == 0:
            feats_per_rank.append(torch.empty((0, 1), dtype=torch.float))
            global_id_per_rank.append(torch.empty((0, 1), dtype=torch.int64))
        else:
            feats_per_rank.append(featdata_key[local_idx_partid])
            global_id_per_rank.append(
                torch.from_numpy(gids_per_partid).type(torch.int64)
            )

    # features (and global nids) per rank to be sent out are ready
    # for transmission, perform alltoallv here.
    output_feat_list = alltoallv_cpu(
        rank, world_size, feats_per_rank, retain_nones=False
    )
    output_id_list = alltoallv_cpu(
        rank, world_size, global_id_per_rank, retain_nones=False
    )
    assert len(output_feat_list) == len(output_id_list), (
        "Length of feature list and id list are expected to be equal while "
        f"got {len(output_feat_list)} and {len(output_id_list)}."
    )

    # stitch node_features together to form one large feature tensor
    if len(output_feat_list) > 0:
        output_feat_list = torch.cat(output_feat_list)
        output_id_list = torch.cat(output_id_list)
        if local_feat_key in cur_features:
            temp = cur_features[local_feat_key]
            cur_features[local_feat_key] = torch.cat([temp, output_feat_list])
            temp = cur_global_ids[local_feat_key]
            cur_global_ids[local_feat_key] = torch.cat([temp, output_id_list])
        else:
            cur_features[local_feat_key] = output_feat_list
            cur_global_ids[local_feat_key] = output_id_list

    return cur_features, cur_global_ids


def exchange_features(
    rank,
    world_size,
    num_parts,
    feature_tids,
    type_id_map,
    id_lookup,
    feature_data,
    feat_type,
    data,
):
    """
    This function is used to shuffle node features so that each process will receive
419
    all the node features whose corresponding nodes are owned by the same process.
420
421
    The mapping procedure to identify the owner process is not straight forward. The
    following steps are used to identify the owner processes for the locally read node-
422
    features.
423
424
    a. Compute the global_nids for the locally read node features. Here metadata json file
        is used to identify the corresponding global_nids. Please note that initial graph input
425
426
427
428
429
        nodes.txt files are sorted based on node_types.
    b. Using global_nids and metis partitions owner processes can be easily identified.
    c. Now each process sends the global_nids for which shuffle_global_nids are needed to be
        retrieved.
    d. After receiving the corresponding shuffle_global_nids these ids are added to the
430
        node_data and edge_data dictionaries
431

432
433
    This pipeline assumes all the input data in numpy format, except node/edge features which
    are maintained as tensors throughout the various stages of the pipeline execution.
434

435
    Parameters:
436
437
438
439
    -----------
    rank : int
        rank of the current process
    world_size : int
440
        total no. of participating processes.
441
    feature_tids : dictionary
442
        dictionary with keys as node-type names with suffixes as feature names
443
444
445
        and value is a dictionary. This dictionary contains information about
        node-features associated with a given node-type and value is a list.
        This list contains a of indexes, like [starting-idx, ending-idx) which
446
        can be used to index into the node feature tensors read from
447
448
        corresponding input files.
    type_id_map : dictionary
449
        mapping between type names and global_ids, of either nodes or edges,
450
        which belong to the keys in this dictionary
451
    id_lookup : instance of class DistLookupService
452
       Distributed lookup service used to map global-nids to respective
453
       partition-ids and shuffle-global-nids
454
    feat_type : string
455
456
457
        this is used to distinguish which features are being exchanged. Please
        note that for nodes ownership is clearly defined and for edges it is
        always assumed that destination end point of the edge defines the
458
459
460
        ownership of that particular edge
    data: dicitonary
        dictionry in which node or edge features are stored and this information
461
        is read from the appropriate node features file which belongs to the
462
        current process
463
464
465

    Returns:
    --------
466
    dictionary :
467
        a dictionary is returned where keys are type names and
468
469
        feature data are the values
    list :
470
        a dictionary of global_ids either nodes or edges whose features are
471
        received during the data shuffle process
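
    Example of a feature_tids entry (illustrative; the third token in the key
    is the index of the chunk read by this rank, and each value is a
    [starting-idx, ending-idx) pair into the locally read feature tensor):

        {
            "paper/feat/0": [[0, 1000]],
        }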
    """
    start = timer()
    own_features = {}
    own_global_ids = {}

    # To iterate over the node_types and associated node_features
    for feat_key, type_info in feature_tids.items():

        # To iterate over the feature data, of a given (node or edge) type
        # type_info is a list of 3 elements (as shown below):
        #   [feature-name, starting-idx, ending-idx]
        #       feature-name is the name given to the feature-data,
        #       read from the input metadata file
        #       [starting-idx, ending-idx) specifies the range of indexes
        #        associated with the features data

        # Determine the owner process for these features.
        # Note that the keys in the node features (and similarly edge features)
        # dictionary is of the following format:
        #   `node_type/feature_name/local_part_id`:
        #    where node_type and feature_name are self-explanatory and
        #    local_part_id denotes the partition-id, in the local process,
        #    which will be used as a suffix to store all the information of a
        #    given partition which is processed by the current process. Its
        #    values start from 0 onwards, for instance 0, 1, 2 ... etc.
        #    local_part_id can be easily mapped to global partition id very
        #    easily, using cyclic ordering. All local_part_ids = 0 from all
        #    processes will form global partition-ids between 0 and world_size-1.
        #    Similarly all local_part_ids = 1 from all processes will form
        #    global partition ids in the range [world_size, 2*world_size-1] and
        #    so on.
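        #    For example (illustrative): with world_size = 2 and num_parts = 4,
        #    rank 0 handles global partitions 0 (local_part_id 0) and
        #    2 (local_part_id 1), while rank 1 handles partitions 1 and 3.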
        tokens = feat_key.split("/")
        assert len(tokens) == 3
        type_name = tokens[0]
        feat_name = tokens[1]
        logging.info(f"[Rank: {rank}] processing feature: {feat_key}")

        for feat_info in type_info:
            # Compute the global_id range for this feature data
            type_id_start = int(feat_info[0])
            type_id_end = int(feat_info[1])
            begin_global_id = type_id_map[type_name][0]
            gid_start = begin_global_id + type_id_start
            gid_end = begin_global_id + type_id_end

            # Check if features exist for this type_name + feat_name.
            # This check should always pass, because feature_tids are built
            # by reading the input metadata json file for existing features.
            assert feat_key in feature_data
            for local_part_id in range(num_parts // world_size):
                featdata_key = feature_data[feat_key]
                own_features, own_global_ids = exchange_feature(
                    rank,
                    data,
                    id_lookup,
                    feat_type,
                    feat_key,
                    featdata_key,
                    gid_start,
                    gid_end,
                    type_id_start,
                    type_id_end,
                    local_part_id,
                    world_size,
                    num_parts,
                    own_features,
                    own_global_ids,
                )

    end = timer()
    logging.info(
        f"[Rank: {rank}] Total time for feature exchange: {timedelta(seconds = end - start)}"
    )
    return own_features, own_global_ids


def exchange_graph_data(
    rank,
    world_size,
    num_parts,
    node_features,
    edge_features,
    node_feat_tids,
    edge_feat_tids,
    edge_data,
    id_lookup,
    ntypes_ntypeid_map,
    ntypes_gnid_range_map,
    etypes_geid_range_map,
    ntid_ntype_map,
    schema_map,
):
    """
    Wrapper function which is used to shuffle graph data on all the processes.

    Parameters:
    -----------
    rank : int
        rank of the current process
    world_size : int
        total no. of participating processes.
    num_parts : int
        total no. of graph partitions.
    node_features : dictionary
        dictionary where node_features are stored and this information is read from the appropriate
        node features file which belongs to the current process
    edge_features : dictionary
        dictionary where edge_features are stored. This information is read from the appropriate
        edge feature files whose ownership is assigned to the current process
    node_feat_tids : dictionary
        in which keys are node-type names and values are triplets. Each triplet has node-feature name
        and the starting and ending type ids of the node-feature data read from the corresponding
        node feature data file read by current process. Each node type may have several features and
        hence each key may have several triplets.
    edge_feat_tids : dictionary
        a dictionary in which keys are edge-type names and values are triplets of the format
        <feat-name, start-per-type-idx, end-per-type-idx>. This triplet is used to identify
        the chunk of feature data for which the current process is responsible
    edge_data : dictionary
        dictionary which is used to store edge information as read from appropriate files assigned
        to each process.
    id_lookup : instance of class DistLookupService
       Distributed lookup service used to map global-nids to respective partition-ids and
       shuffle-global-nids
    ntypes_ntypeid_map : dictionary
        mappings between node type names and node type ids
    ntypes_gnid_range_map : dictionary
        mapping between node type names and global_nids which belong to the keys in this dictionary
    etypes_geid_range_map : dictionary
        mapping between edge type names and global_eids which are assigned to the edges of this
        edge_type
    ntid_ntype_map : dictionary
        mapping between node type ids and node type names
    schema_map : dictionary
        is the data structure read from the metadata json file for the input graph

    Returns:
    --------
    dictionary :
        the node_data dictionary, generated locally via `gen_node_data`, containing the nodes owned
        by the current process after data shuffling.
    dictionary :
        node features dictionary which has node features for the nodes which are owned by the current
        process
    dictionary :
        list of global_nids for the nodes whose node features are received when node features shuffling was
        performed in the `exchange_features` function call
    dictionary :
        the input argument, edge_data dictionary, is updated with the edge data received from other processes
        in the world. The edge data is received by each rank in the process of data shuffling.
    dictionary :
        edge features dictionary which has edge features. The destination end points of these edges
        are owned by the current process
    dictionary :
        list of global_eids for the edges whose edge features are received when edge features shuffling
        was performed in the `exchange_features` function call
    """
    memory_snapshot("ShuffleNodeFeaturesBegin: ", rank)
    rcvd_node_features, rcvd_global_nids = exchange_features(
        rank,
        world_size,
        num_parts,
        node_feat_tids,
        ntypes_gnid_range_map,
        id_lookup,
        node_features,
        constants.STR_NODE_FEATURES,
        None,
    )
    memory_snapshot("ShuffleNodeFeaturesComplete: ", rank)
    logging.info(f"[Rank: {rank}] Done with node features exchange.")

    rcvd_edge_features, rcvd_global_eids = exchange_features(
        rank,
        world_size,
        num_parts,
        edge_feat_tids,
        etypes_geid_range_map,
        id_lookup,
        edge_features,
        constants.STR_EDGE_FEATURES,
        edge_data,
    )
    logging.info(f"[Rank: {rank}] Done with edge features exchange.")

    node_data = gen_node_data(
        rank, world_size, num_parts, id_lookup, ntid_ntype_map, schema_map
    )
    memory_snapshot("NodeDataGenerationComplete: ", rank)

    edge_data = exchange_edge_data(rank, world_size, num_parts, edge_data)
    memory_snapshot("ShuffleEdgeDataComplete: ", rank)
    return (
        node_data,
        rcvd_node_features,
        rcvd_global_nids,
        edge_data,
        rcvd_edge_features,
        rcvd_global_eids,
    )


def read_dataset(rank, world_size, id_lookup, params, schema_map):
    """
    This function gets the dataset and performs post-processing on the data which is read from files.
677
    Additional information(columns) are added to nodes metadata like owner_process, global_nid which
678
679
    are later used in processing this information. For edge data, which is now a dictionary, we add new columns
    like global_edge_id and owner_process. Augmenting these data structure helps in processing these data structures
680
    when data shuffling is performed.
681
682
683
684
685

    Parameters:
    -----------
    rank : int
        rank of the current process
686
    world_size : int
687
        total no. of processes instantiated
688
    id_lookup : instance of class DistLookupService
689
       Distributed lookup service used to map global-nids to respective partition-ids and
690
       shuffle-global-nids
691
    params : argparser object
692
        argument parser object to access command line arguments
693
694
    schema_map : dictionary
        dictionary created by reading the input graph metadata json file
695

696
    Returns :
697
698
    ---------
    dictionary
699
700
        in which keys are node-type names and values are are tuples representing the range of ids
        for nodes to be read by the current process
701
702
    dictionary
        node features which is a dictionary where keys are feature names and values are feature
703
        data as multi-dimensional tensors
704
705
706
707
708
    dictionary
        in which keys are node-type names and values are triplets. Each triplet has node-feature name
        and the starting and ending type ids of the node-feature data read from the corresponding
        node feature data file read by current process. Each node type may have several features and
        hence each key may have several triplets.
709
    dictionary
710
711
        edge data information is read from edges.txt and additional columns are added such as
        owner process for each edge.
712
713
    dictionary
        edge features which is also a dictionary, similar to node features dictionary
714
715
716
717
    dictionary
        a dictionary in which keys are edge-type names and values are tuples indicating the range of ids
        for edges read by the current process.
    dictionary
718
        a dictionary in which keys are edge-type names and values are triplets,
719
720
        (edge-feature-name, start_type_id, end_type_id). These type_ids are indices in the edge-features
        read by the current process. Note that each edge-type may have several edge-features.
721
722
    """
    edge_features = {}
    # node_tids, node_features, edge_datadict, edge_tids
    (
        node_tids,
        node_features,
        node_feat_tids,
        edge_data,
        edge_tids,
        edge_features,
        edge_feat_tids,
    ) = get_dataset(
        params.input_dir,
        params.graph_name,
        rank,
        world_size,
        params.num_parts,
        schema_map,
    )
    logging.info(f"[Rank: {rank}] Done reading dataset {params.input_dir}")

    edge_data = augment_edge_data(
        edge_data, id_lookup, edge_tids, rank, world_size, params.num_parts
    )
    logging.info(
        f"[Rank: {rank}] Done augmenting edge_data: {len(edge_data)}, {edge_data[constants.GLOBAL_SRC_ID].shape}"
    )

    return (
        node_tids,
        node_features,
        node_feat_tids,
        edge_data,
        edge_features,
        edge_tids,
        edge_feat_tids,
    )


def gen_dist_partitions(rank, world_size, params):
    """
    Function which will be executed by all Gloo processes to begin execution of the pipeline.
    This function expects the input dataset to be split across multiple files.

    The input dataset and its file structure are described in the metadata json file which is also part of the
    input dataset. On a high-level, this metadata json file contains information about the following items:
    a) Nodes metadata: it is assumed that nodes which belong to each node-type are split into p files
       (where `p` is the no. of partitions).
    b) Similarly, edge metadata contains information about edges which are split into p files.
    c) Node and edge features: it is also assumed that each node (and edge) feature, if present, is also
       split into `p` files.

    For example, a sample metadata json file might be as follows:
    (In this toy example, we assume that we have "m" node-types, "k" edge types, and for node_type = ntype0-name
     we have two features namely feat0-name and feat1-name. Please note that the node-features are also split into
     `p` files. This will help in load-balancing during the data-shuffling phase).

    Terminology used to identify any particular "id" assigned to nodes, edges or node features: Prefix "global" is
    used to indicate that this information is either read from the input dataset or autogenerated based on the information
    read from input dataset files. Prefix "type" is used to indicate a unique id assigned to either nodes or edges.
    For instance, type_node_id means a unique id, within a given node type, assigned to a node. And prefix "shuffle"
    will be used to indicate a unique id, across the entire graph, assigned to either a node or an edge. For instance,
    SHUFFLE_GLOBAL_NID means a unique id which is assigned to a node after the data shuffle is completed.

    Some high-level notes on the structure of the metadata json file:
    1. Path(s) mentioned in the entries for nodes, edges and node-features files can be either absolute or relative.
       If these paths are relative, then it is assumed that they are relative to the folder from which the execution is
       launched.
    2. The id_startx and id_endx represent the type_node_id and type_edge_id respectively for nodes and edge data. This
       means that these ids should match the no. of nodes/edges read from any given file. Since these are type_ids for
       the nodes and edges in any given file, their global_ids can be easily computed as well.

    {
        "graph_name" : xyz,
        "node_type" : ["ntype0-name", "ntype1-name", ....], # m node types
        "num_nodes_per_chunk" : [
            [a0, a1, ...a<p-1>], # p partitions
            [b0, b1, ... b<p-1>],
            ....
            [c0, c1, ..., c<p-1>] # no. of node types
        ],
        "edge_type" : ["src_ntype:edge_type:dst_ntype", ....], # k edge types
        "num_edges_per_chunk" : [
            [a0, a1, ...a<p-1>], # p partitions
            [b0, b1, ... b<p-1>],
            ....
            [c0, c1, ..., c<p-1>] # no. of edge types
        ],
        "node_data" : {
            "ntype0-name" : {
                "feat0-name" : {
                    "format" : {"name": "numpy"},
                    "data" :   [ # list of lists
                        ["<path>/feat-0.npy", 0, id_end0],
                        ["<path>/feat-1.npy", id_start1, id_end1],
                        ....
                        ["<path>/feat-<p-1>.npy", id_start<p-1>, id_end<p-1>]
                    ]
                },
                "feat1-name" : {
                    "format" : {"name": "numpy"},
                    "data" : [ # list of lists
                        ["<path>/feat-0.npy", 0, id_end0],
                        ["<path>/feat-1.npy", id_start1, id_end1],
                        ....
                        ["<path>/feat-<p-1>.npy", id_start<p-1>, id_end<p-1>]
                    ]
                }
            }
        },
        "edges": { # k edge types
            "src_ntype:etype0-name:dst_ntype" : {
                "format": {"name" : "csv", "delimiter" : " "},
                "data" : [
                    ["<path>/etype0-name-0.txt", 0, id_end0], # These are type_edge_ids for edges of this type
                    ["<path>/etype0-name-1.txt", id_start1, id_end1],
                    ...,
                    ["<path>/etype0-name-<p-1>.txt", id_start<p-1>, id_end<p-1>]
                ]
            },
            ...,
            "src_ntype:etype<k-1>-name:dst_ntype" : {
                "format": {"name" : "csv", "delimiter" : " "},
                "data" : [
                    ["<path>/etype<k-1>-name-0.txt", 0, id_end0],
                    ["<path>/etype<k-1>-name-1.txt", id_start1, id_end1],
                    ...,
                    ["<path>/etype<k-1>-name-<p-1>.txt", id_start<p-1>, id_end<p-1>]
                ]
            },
        },
    }

    The function performs the following steps:
    1. Reads the metis partitions to identify the owner process of all the nodes in the entire graph.
    2. Reads the input data set; each participating process will map to a single file for the edges,
        node-features and edge-features for each node-type and edge-type respectively. Using the nodes metadata
        information, nodes which are owned by a given process are generated to optimize communication to some
        extent.
    3. Now each process shuffles the data by identifying the respective owner processes using metis
        partitions.
        a. To identify owner processes for nodes, metis partitions will be used.
        b. For edges, the owner process of the destination node will be the owner of the edge as well.
        c. For node and edge features, identifying the owner process is a little bit involved.
            For this purpose, the graph metadata json file is used to first map the locally read node features
            to their global_nids. Now the owner process is identified using metis partitions for these global_nids
            to retrieve shuffle_global_nids. A similar process is used for edge_features as well.
        d. After all the data shuffling is done, the order of node-features may be different when compared to
            their global_type_nids. Node- and edge-data are ordered by node-type and edge-type respectively.
            And now node features and edge features are re-ordered to match the order of their node- and edge-types.
    4. Last step is to create the DGL objects with the data present on each of the processes.
        a. DGL objects for nodes, edges, node- and edge-features.
        b. Metadata is gathered from each process to create the global metadata json file, by process rank = 0.

    Parameters:
    ----------
    rank : int
        integer representing the rank of the current process in a typical distributed implementation
    world_size : int
        integer representing the total no. of participating processes in a typical distributed implementation
    params : argparser object
        this object, key value pairs, provides access to the command line arguments from the runtime environment
    """
    global_start = timer()
    logging.info(
        f"[Rank: {rank}] Starting distributed data processing pipeline..."
    )
    memory_snapshot("Pipeline Begin: ", rank)
    # init processing
    schema_map = read_json(os.path.join(params.input_dir, params.schema))

    # Initialize distributed lookup service for partition-id and shuffle-global-nids mappings
    # for global-nids
    _, global_nid_ranges = get_idranges(
        schema_map[constants.STR_NODE_TYPE],
        schema_map[constants.STR_NUM_NODES_PER_CHUNK],
        params.num_parts,
    )
    id_map = dgl.distributed.id_map.IdMap(global_nid_ranges)

    # The resources, which are node-id to partition-id mappings, are split
    # into `world_size` number of parts, where each part can be mapped to
    # each physical node.
    id_lookup = DistLookupService(
        os.path.join(params.input_dir, params.partitions_dir),
        schema_map[constants.STR_NODE_TYPE],
        id_map,
        rank,
        world_size,
    )

    ntypes_ntypeid_map, ntypes, ntypeid_ntypes_map = get_node_types(schema_map)
    etypes_etypeid_map, etypes, etypeid_etypes_map = get_edge_types(schema_map)
    logging.info(
        f"[Rank: {rank}] Initialized metis partitions and node_types map..."
    )

    # read input graph files and augment these datastructures with
    # appropriate information (global_nid and owner process) for node and edge data
    (
        node_tids,
        node_features,
        node_feat_tids,
        edge_data,
        edge_features,
        edge_tids,
        edge_feat_tids,
    ) = read_dataset(rank, world_size, id_lookup, params, schema_map)
    logging.info(
        f"[Rank: {rank}] Done augmenting file input data with auxiliary columns"
    )
    memory_snapshot("DatasetReadComplete: ", rank)

    # send out node and edge data --- and appropriate features.
    # this function will also stitch the data recvd from other processes
    # and return the aggregated data
    ntypes_gnid_range_map = get_gnid_range_map(node_tids)
    etypes_geid_range_map = get_gnid_range_map(edge_tids)
    (
        node_data,
        rcvd_node_features,
        rcvd_global_nids,
        edge_data,
        rcvd_edge_features,
        rcvd_global_eids,
    ) = exchange_graph_data(
        rank,
        world_size,
        params.num_parts,
        node_features,
        edge_features,
        node_feat_tids,
        edge_feat_tids,
        edge_data,
        id_lookup,
        ntypes_ntypeid_map,
        ntypes_gnid_range_map,
        etypes_geid_range_map,
        ntypeid_ntypes_map,
        schema_map,
    )
    gc.collect()
    logging.info(f"[Rank: {rank}] Done with data shuffling...")
    memory_snapshot("DataShuffleComplete: ", rank)

    # sort node_data by ntype
    for local_part_id in range(params.num_parts // world_size):
        idx = node_data[constants.NTYPE_ID + "/" + str(local_part_id)].argsort()
        for k, v in node_data.items():
            tokens = k.split("/")
            assert len(tokens) == 2
            if tokens[1] == str(local_part_id):
                node_data[k] = v[idx]
        idx = None
    gc.collect()
    logging.info(f"[Rank: {rank}] Sorted node_data by node_type")

    # resolve global_ids for nodes
    assign_shuffle_global_nids_nodes(
        rank, world_size, params.num_parts, node_data
    )
    logging.info(f"[Rank: {rank}] Done assigning global-ids to nodes...")
    memory_snapshot("ShuffleGlobalID_Nodes_Complete: ", rank)
    # shuffle node feature according to the node order on each rank.
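    # For each node feature tensor received above, the global-nids that
    # arrived with it are mapped to shuffle-global-nids via `node_data`, and
    # the feature rows are then re-ordered to follow the shuffle-global-nid
    # ordering of the nodes owned by this partition.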
    for ntype_name in ntypes:
        featnames = get_ntype_featnames(ntype_name, schema_map)
        for featname in featnames:
            # if a feature name exists for a node-type, then it should also have
            # feature data as well. Hence using the assert statement.
            for local_part_id in range(params.num_parts // world_size):
                feature_key = (
                    ntype_name + "/" + featname + "/" + str(local_part_id)
                )
                assert feature_key in rcvd_global_nids
                global_nids = rcvd_global_nids[feature_key]

                _, idx1, _ = np.intersect1d(
                    node_data[constants.GLOBAL_NID + "/" + str(local_part_id)],
                    global_nids,
                    return_indices=True,
                )
                shuffle_global_ids = node_data[
                    constants.SHUFFLE_GLOBAL_NID + "/" + str(local_part_id)
                ][idx1]
                feature_idx = shuffle_global_ids.argsort()

                rcvd_node_features[feature_key] = rcvd_node_features[
                    feature_key
                ][feature_idx]
    memory_snapshot("ReorderNodeFeaturesComplete: ", rank)

    # sort edge_data by etype
    for local_part_id in range(params.num_parts // world_size):
        sorted_idx = edge_data[
            constants.ETYPE_ID + "/" + str(local_part_id)
        ].argsort()
        for k, v in edge_data.items():
            tokens = k.split("/")
            assert len(tokens) == 2
            if tokens[1] == str(local_part_id):
                edge_data[k] = v[sorted_idx]
        sorted_idx = None
    gc.collect()

    shuffle_global_eid_offsets = assign_shuffle_global_nids_edges(
        rank, world_size, params.num_parts, edge_data
    )
    logging.info(f"[Rank: {rank}] Done assigning global_ids to edges ...")
    memory_snapshot("ShuffleGlobalID_Edges_Complete: ", rank)

    # Shuffle edge features according to the edge order on each rank.
    for etype_name in etypes:
        featnames = get_etype_featnames(etype_name, schema_map)
        for featname in featnames:
            for local_part_id in range(params.num_parts // world_size):
                feature_key = (
                    etype_name + "/" + featname + "/" + str(local_part_id)
                )
                assert feature_key in rcvd_global_eids
                global_eids = rcvd_global_eids[feature_key]

                _, idx1, _ = np.intersect1d(
                    edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)],
                    global_eids,
                    return_indices=True,
                )
                shuffle_global_ids = edge_data[
                    constants.SHUFFLE_GLOBAL_EID + "/" + str(local_part_id)
                ][idx1]
                feature_idx = shuffle_global_ids.argsort()

                rcvd_edge_features[feature_key] = rcvd_edge_features[
                    feature_key
                ][feature_idx]

    # determine global-ids for edge end-points
    edge_data = lookup_shuffle_global_nids_edges(
        rank, world_size, params.num_parts, edge_data, id_lookup, node_data
    )
    logging.info(
        f"[Rank: {rank}] Done resolving orig_node_id for local node_ids..."
    )
    memory_snapshot("ShuffleGlobalID_Lookup_Complete: ", rank)
    def prepare_local_data(src_data, local_part_id):
        local_data = {}
        for k, v in src_data.items():
            tokens = k.split("/")
            if tokens[len(tokens) - 1] == str(local_part_id):
                local_data["/".join(tokens[:-1])] = v
        return local_data
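    # For example (illustrative): with local_part_id == 1, a key such as
    # "<column-name>/1" in `src_data` is copied into `local_data` under the
    # key "<column-name>", while keys with other partition-id suffixes are
    # skipped.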

    # create dgl objects here
    output_meta_json = {}
    start = timer()

    graph_formats = None
    if params.graph_formats:
        graph_formats = params.graph_formats.split(",")

    for local_part_id in range(params.num_parts // world_size):
        num_edges = shuffle_global_eid_offsets[local_part_id]
        node_count = len(
            node_data[constants.NTYPE_ID + "/" + str(local_part_id)]
        )
        edge_count = len(
            edge_data[constants.ETYPE_ID + "/" + str(local_part_id)]
        )
        local_node_data = prepare_local_data(node_data, local_part_id)
        local_edge_data = prepare_local_data(edge_data, local_part_id)
        (
            graph_obj,
            ntypes_map_val,
            etypes_map_val,
            ntypes_map,
            etypes_map,
            orig_nids,
            orig_eids,
        ) = create_dgl_object(
            schema_map,
            rank + local_part_id * world_size,
            local_node_data,
            local_edge_data,
            num_edges,
            params.save_orig_nids,
            params.save_orig_eids,
        )
        sort_etypes = len(etypes_map) > 1
        local_node_features = prepare_local_data(
            rcvd_node_features, local_part_id
        )
        local_edge_features = prepare_local_data(
            rcvd_edge_features, local_part_id
        )
        write_dgl_objects(
            graph_obj,
            local_node_features,
            local_edge_features,
            params.output,
            rank + (local_part_id * world_size),
            orig_nids,
            orig_eids,
            graph_formats,
            sort_etypes,
        )
        memory_snapshot("DiskWriteDGLObjectsComplete: ", rank)

        # get the meta-data
        json_metadata = create_metadata_json(
            params.graph_name,
            node_count,
            edge_count,
            local_part_id * world_size + rank,
            params.num_parts,
            ntypes_map_val,
            etypes_map_val,
            ntypes_map,
            etypes_map,
            params.output,
        )
        output_meta_json[
            "local-part-id-" + str(local_part_id * world_size + rank)
        ] = json_metadata
        memory_snapshot("MetadataCreateComplete: ", rank)

    if rank == 0:
        # get meta-data from all partitions and merge them on rank-0
        metadata_list = gather_metadata_json(output_meta_json, rank, world_size)
        metadata_list[0] = output_meta_json
        write_metadata_json(
            metadata_list,
            params.output,
            params.graph_name,
            world_size,
            params.num_parts,
        )
    else:
        # send meta-data to Rank-0 process
        gather_metadata_json(output_meta_json, rank, world_size)
    end = timer()
    logging.info(
        f"[Rank: {rank}] Time to create dgl objects: {timedelta(seconds = end - start)}"
    )
    memory_snapshot("MetadataWriteComplete: ", rank)

    global_end = timer()
    logging.info(
        f"[Rank: {rank}] Total execution time of the program: {timedelta(seconds = global_end - global_start)}"
    )
    memory_snapshot("PipelineComplete: ", rank)

def single_machine_run(params):
    """Main function for distributed implementation on a single machine

    Parameters:
    -----------
    params : argparser object
        Argument Parser structure with pre-determined arguments as defined
        at the bottom of this file.
    """
    log_params(params)
    processes = []
    mp.set_start_method("spawn")

    # Invoke `target` function from each of the spawned process for distributed
    # implementation
    for rank in range(params.world_size):
        p = mp.Process(
            target=run,
            args=(rank, params.world_size, gen_dist_partitions, params),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


def run(rank, world_size, func_exec, params, backend="gloo"):
    """
    Init. function which is run by each process in the Gloo ProcessGroup

    Parameters:
    -----------
    rank : integer
        rank of the process
    world_size : integer
        number of processes configured in the Process Group
    func_exec : function
        function which will be invoked which has the logic for each process in the group
    params : argparser object
        argument parser object to access the command line arguments
    backend : string
        string specifying the type of backend to use for communication
    """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"

    # create Gloo Process Group
    dist.init_process_group(
        backend,
        rank=rank,
        world_size=world_size,
        timeout=timedelta(seconds=5 * 60),
    )

    # Invoke the main function to kick-off each process
    func_exec(rank, world_size, params)


def multi_machine_run(params):
    """
    Function to be invoked when executing data loading pipeline on multiple machines

    Parameters:
    -----------
    params : argparser object
        argparser object providing access to command line arguments.
    """
    rank = int(os.environ["RANK"])

    # init the gloo process group here.
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=params.world_size,
        timeout=timedelta(seconds=params.process_group_timeout),
    )
    logging.info(f"[Rank: {rank}] Done with process group initialization...")

    # invoke the main function here.
    gen_dist_partitions(rank, params.world_size, params)
    logging.info(
        f"[Rank: {rank}] Done with Distributed data processing pipeline processing."
    )