import gc
import logging
import os

import array_readwriter
import constants
import numpy as np
import pyarrow
import pyarrow.csv
import pyarrow.parquet as pq
import torch
import torch.distributed as dist
from gloo_wrapper import alltoallv_cpu
from utils import (
    DATA_TYPE_ID,
    generate_read_list,
    get_gid_offsets,
    get_idranges,
    map_partid_rank,
    REV_DATA_TYPE_ID,
)


def _broadcast_shape(
    data, rank, world_size, num_parts, is_feat_data, feat_name
):
    """Auxiliary function to broadcast the shape of feature data.
    This information is used to figure out the type-ids for the
    local features.

    Parameters:
    -----------
    data : numpy array or torch tensor
        the feature or edge data read from the disk
    rank : integer
        id of the current process in the process group
    world_size : integer
        total no. of processes in the process group
    num_parts : integer
        no. of partitions
    is_feat_data : bool
        flag used to separate feature data from edge data
    feat_name : string
        name of the feature

    Returns:
    -------
    list of tuples :
        the range of type-ids for the data array.
    """
    assert len(data.shape) in [
        1,
        2,
    ], f"Data is expected to be 1-D or 2-D but got {data.shape}."
    data_shape = list(data.shape)

    if len(data_shape) == 1:
        data_shape.append(1)

    if is_feat_data:
        data_shape.append(DATA_TYPE_ID[data.dtype])

    data_shape = torch.tensor(data_shape, dtype=torch.int64)
    data_shape_output = [
        torch.zeros_like(data_shape) for _ in range(world_size)
    ]
    dist.all_gather(data_shape_output, data_shape)
    logging.debug(
        f"[Rank: {rank}] Received shapes from all ranks: {data_shape_output}"
    )
    shapes = [x.numpy() for x in data_shape_output if x[0] != 0]
    shapes = np.vstack(shapes)

    if is_feat_data:
        logging.debug(
            f"shapes: {shapes}, condition: {all(shapes[0,2] == s for s in shapes[:,2])}"
        )
        assert all(
            shapes[0, 2] == s for s in shapes[:, 2]
        ), f"dtypes for {feat_name} do not match across all ranks"

    # compute tids here.
    type_counts = list(shapes[:, 0])
    tid_start = np.cumsum([0] + type_counts[:-1])
    tid_end = np.cumsum(type_counts)
    tid_ranges = list(zip(tid_start, tid_end))
    logging.debug(f"starts -> {tid_start} ... end -> {tid_end}")

    return tid_ranges
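
    # A worked example of the tid computation above (hypothetical values, for
    # illustration only): if four ranks read feature chunks with 3, 5, 4 and 0
    # rows respectively, the gathered row counts are [3, 5, 4] (the empty
    # chunk is skipped), so tid_start = [0, 3, 8], tid_end = [3, 8, 12], and
    # tid_ranges = [(0, 3), (3, 8), (8, 12)] is returned.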


def get_dataset(
    input_dir, graph_name, rank, world_size, num_parts, schema_map, ntype_counts
):
    """
    Function to read a dataset that is split across multiple files.

    Parameters:
    -----------
    input_dir : string
        root directory where the dataset is located.
    graph_name : string
        graph name string
    rank : int
        rank of the current process
    world_size : int
        total number of processes in the current execution
    num_parts : int
        total number of output graph partitions
    schema_map : dictionary
        dictionary created by reading the graph metadata json file
        for the input graph dataset
    ntype_counts : dictionary
        node counts for each node type in the input graph, used to compute
        the global node-id offsets

    Returns:
    --------
    dictionary
        node feature data read from numpy or parquet files. Keys are of the form
        `ntype_name/feat_name/local_part_id` and values are tensors with the corresponding
        node feature data (or None when the current process did not read that chunk). Note
        that the data for each node feature is split into "p" files and each of these "p"
        files is read by a process in the distributed graph partitioning pipeline.
    dictionary
        with the same keys as the node feature dictionary above; each value is a list with
        the (start, end) range of type-ids covered by the node feature data read by the
        current process. Each node-type may have multiple features and associated tensor data.
    dictionary
        edge data read from the edge files, with column names (global source and destination
        node-ids, global type edge-ids and edge-type ids) as keys and the corresponding
        columns as values.
    dictionary
        with edge-type names as keys and the total number of edges of that type as values.
    dictionary
        with edge-type names as keys; each value is the list of (start, end) type-id ranges
        describing how the edges of that type are split across the input chunks.
    dictionary
        edge feature data read from numpy or parquet files. Keys are of the form
        `etype_name/feat_name/local_part_id` and values are tensors with the corresponding
        edge feature data (or None when the current process did not read that chunk).
    dictionary
        with the same keys as the edge feature dictionary above; each value is a list with
        the (start, end) range of type-ids covered by the edge feature data read by the
        current process.
    """

    # node features dictionary
    # TODO: With the new file format, it is guaranteed that the number of nodes in any
    # node-feature file and the number of nodes in the corresponding nodes metadata file
    # will always be the same. With this guarantee, we can eliminate the `node_feature_tids`
    # dictionary, since the same information is also populated in the `node_tids` dictionary.
    # This will be removed in the next iteration of code changes.
    node_features = {}
    node_feature_tids = {}

    """
157
    The structure of the node_data is as follows, which is present in the input metadata json file.
158
159
160
161
162
163
164
165
166
167
168
169
       "node_data" : {
            "ntype0-name" : {
                "feat0-name" : {
                    "format" : {"name": "numpy"},
                    "data" :   [ #list
                        "<path>/feat-0.npy",
                        "<path>/feat-1.npy",
                        ....
                        "<path>/feat-<p-1>.npy"
                    ]
                },
                "feat1-name" : {
170
171
                    "format" : {"name": "numpy"},
                    "data" : [ #list
172
173
174
175
176
177
178
179
180
                        "<path>/feat-0.npy",
                        "<path>/feat-1.npy",
                        ....
                        "<path>/feat-<p-1>.npy"
                    ]
                }
            }
       }

181
    As shown above, the value for the key "node_data" is a dictionary object, which is
182
183
184
185
186
187
188
189
    used to describe the feature data for each of the node-type names. Keys in this top-level
    dictionary are node-type names and value is a dictionary which captures all the features
    for the current node-type. Feature data is captured with keys being the feature-names and
    value is a dictionary object which has 2 keys namely format and data. Format entry is used
    to mention the format of the storage used by the node features themselves and "data" is used
    to mention all the files present for this given node feature.

    Data read from each of the node features file is a multi-dimensional tensor data and is read
190
    in numpy or parquet format, which is also the storage format of node features on the permanent storage.
191

192
193
194
        "node_type" : ["ntype0-name", "ntype1-name", ....], #m node types
        "num_nodes_per_chunk" : [
            [a0, a1, ...a<p-1>], #p partitions
195
            [b0, b1, ... b<p-1>],
196
197
198
199
200
201
202
203
204
205
206
207
            ....
            [c0, c1, ..., c<p-1>] #no, of node types
        ],

    The "node_type" points to a list of all the node names present in the graph
    And "num_nodes_per_chunk" is used to mention no. of nodes present in each of the
    input nodes files. These node counters are used to compute the type_node_ids as
    well as global node-ids by using a simple cumulative summation and maitaining an
    offset counter to store the end of the current.

    Since nodes are NOT actually associated with any additional metadata, w.r.t to the processing
    involved in this pipeline this information is not needed to be stored in files. This optimization
208
    saves a considerable amount of time when loading massively large datasets for paritioning.
209
210
211
212
    As opposed to reading from files and performing shuffling process each process/rank generates nodes
    which are owned by that particular rank. And using the "num_nodes_per_chunk" information each
    process can easily compute any nodes per-type node_id and global node_id.
    The node-ids are treated as int64's in order to support billions of nodes in the input graph.
213
    """

    # read my nodes for each node type
    """
    node_tids, ntype_gnid_offset = get_idranges(
        schema_map[constants.STR_NODE_TYPE],
        schema_map[constants.STR_NUM_NODES_PER_CHUNK],
        num_chunks=num_parts,
    )
    """
    logging.debug(f"[Rank: {rank}] ntype_counts: {ntype_counts}")
    ntype_gnid_offset = get_gid_offsets(
        schema_map[constants.STR_NODE_TYPE], ntype_counts
    )
    logging.debug(f"[Rank: {rank}] ntype_gnid_offset = {ntype_gnid_offset}")

    # iterate over the "node_data" dictionary in the schema_map
    # read the node features if they exist
    # also keep track of the type_nids for which the node_features are read.
    dataset_features = schema_map[constants.STR_NODE_DATA]
    if (dataset_features is not None) and (len(dataset_features) > 0):
        for ntype_name, ntype_feature_data in dataset_features.items():
            for feat_name, feat_data in ntype_feature_data.items():
                assert feat_data[constants.STR_FORMAT][constants.STR_NAME] in [
                    constants.STR_NUMPY,
                    constants.STR_PARQUET,
                ]

                # It is guaranteed that num_chunks is always greater
                # than num_partitions.
                node_data = []
                num_files = len(feat_data[constants.STR_DATA])
                if num_files == 0:
                    continue
                reader_fmt_meta = {
                    "name": feat_data[constants.STR_FORMAT][constants.STR_NAME]
                }
                read_list = generate_read_list(num_files, world_size)
                for idx in read_list[rank]:
                    data_file = feat_data[constants.STR_DATA][idx]
                    if not os.path.isabs(data_file):
                        data_file = os.path.join(input_dir, data_file)
                    node_data.append(
                        array_readwriter.get_array_parser(
                            **reader_fmt_meta
                        ).read(data_file)
                    )
                if len(node_data) > 0:
                    node_data = np.concatenate(node_data)
                else:
                    node_data = np.array([])
                node_data = torch.from_numpy(node_data)
                cur_tids = _broadcast_shape(
                    node_data,
                    rank,
                    world_size,
                    num_parts,
                    True,
                    f"{ntype_name}/{feat_name}",
                )
                logging.debug(f"[Rank: {rank}] cur_tids: {cur_tids}")

                # collect data on the current rank.
                for local_part_id in range(num_parts):
                    data_key = (
                        f"{ntype_name}/{feat_name}/{local_part_id//world_size}"
                    )
                    if map_partid_rank(local_part_id, world_size) == rank:
                        if len(cur_tids) > local_part_id:
                            start, end = cur_tids[local_part_id]
                            assert node_data.shape[0] == (
                                end - start
                            ), f"Node feature data, {data_key}, shape = {node_data.shape} does not match with tids = ({start},{end})"
                            node_features[data_key] = node_data
                            node_feature_tids[data_key] = [(start, end)]
                        else:
                            node_features[data_key] = None
                            node_feature_tids[data_key] = [(0, 0)]
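
    # Example of the key layout (hypothetical names, for illustration only):
    # for node type "paper" and feature "feat" with world_size = 2, the rank
    # that owns partition 3 stores the tensor under the key "paper/feat/1"
    # (i.e. local_part_id // world_size) together with the tid range
    # cur_tids[3].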

    # done building node_features locally.
    if len(node_features) <= 0:
        logging.debug(
            f"[Rank: {rank}] This dataset does not have any node features"
        )
    else:
        assert len(node_features) == len(node_feature_tids)

        # Note that the keys in the node_features dictionary are as follows:
        # `ntype_name/feat_name/local_part_id`,
        #   where ntype_name and feat_name are self-explanatory, and
        #   local_part_id indicates the partition-id, in the context of the
        #   current process, which takes the values 0, 1, 2, ....
        for feat_name, feat_info in node_features.items():
            if feat_info is None:
                continue

            logging.debug(
                f"[Rank: {rank}] node feature name: {feat_name}, feature data shape: {feat_info.size()}"
            )
            tokens = feat_name.split("/")
            assert len(tokens) == 3

            # Get the range of type ids which are mapped to the current node feature.
            tids = node_feature_tids[feat_name]

            # Iterate over the range of type ids for the current node feature
            # and count the number of rows for this feature name.
            count = tids[0][1] - tids[0][0]
            assert (
                count == feat_info.size()[0]
            ), f"{feat_name}, {count} vs {feat_info.size()[0]}."

    """
    Reading edge features now.
    The structure of the edge_data is as follows, which is present in the input metadata json file.

       "edge_data" : {
            "etype0-name" : {
                "feat0-name" : {
                    "format" : {"name": "numpy"},
                    "data" :   [ #list
                        "<path>/feat-0.npy",
                        "<path>/feat-1.npy",
                        ....
                        "<path>/feat-<p-1>.npy"
                    ]
                },
                "feat1-name" : {
                    "format" : {"name": "numpy"},
                    "data" : [ #list
                        "<path>/feat-0.npy",
                        "<path>/feat-1.npy",
                        ....
                        "<path>/feat-<p-1>.npy"
                    ]
                }
            }
       }

    As shown above, the value for the key "edge_data" is a dictionary object, which is
    used to describe the feature data for each of the edge-type names. Keys in this top-level
    dictionary are edge-type names and each value is a dictionary which captures all the features
    for the current edge-type. Feature data is captured with feature names as keys; each value is
    a dictionary object with 2 keys, namely `format` and `data`. The `format` entry describes the
    storage format used by the edge features themselves, and `data` lists all the files present
    for the given edge feature.

    Data read from each edge feature file is a multi-dimensional tensor and is read in numpy or
    parquet format, which is also the storage format of edge features on permanent storage.
    """
    edge_features = {}
    edge_feature_tids = {}

    # Iterate over the "edge_data" dictionary in the schema_map.
    # Read the edge features if they exist.
    # Also keep track of the type_eids for which the edge_features are read.
    dataset_features = schema_map[constants.STR_EDGE_DATA]
    if dataset_features and (len(dataset_features) > 0):
        for etype_name, etype_feature_data in dataset_features.items():
            for feat_name, feat_data in etype_feature_data.items():
                assert feat_data[constants.STR_FORMAT][constants.STR_NAME] in [
                    constants.STR_NUMPY,
                    constants.STR_PARQUET,
                ]

                edge_data = []
                num_files = len(feat_data[constants.STR_DATA])
                if num_files == 0:
                    continue
                reader_fmt_meta = {
                    "name": feat_data[constants.STR_FORMAT][constants.STR_NAME]
                }
                read_list = generate_read_list(num_files, world_size)
                for idx in read_list[rank]:
                    data_file = feat_data[constants.STR_DATA][idx]
                    if not os.path.isabs(data_file):
                        data_file = os.path.join(input_dir, data_file)
                    logging.debug(
                        f"[Rank: {rank}] Loading edge-feats of {etype_name}[{feat_name}] from {data_file}"
                    )
                    edge_data.append(
                        array_readwriter.get_array_parser(
                            **reader_fmt_meta
                        ).read(data_file)
                    )
                if len(edge_data) > 0:
                    edge_data = np.concatenate(edge_data)
                else:
                    edge_data = np.array([])
                edge_data = torch.from_numpy(edge_data)

                # exchange the amount of data read from the disk.
                edge_tids = _broadcast_shape(
                    edge_data,
                    rank,
                    world_size,
                    num_parts,
                    True,
                    f"{etype_name}/{feat_name}",
                )

                # collect data on the current rank.
                for local_part_id in range(num_parts):
                    data_key = (
                        f"{etype_name}/{feat_name}/{local_part_id//world_size}"
                    )
                    if map_partid_rank(local_part_id, world_size) == rank:
                        if len(edge_tids) > local_part_id:
                            start, end = edge_tids[local_part_id]
                            assert edge_data.shape[0] == (
                                end - start
                            ), f"Edge feature data, for {data_key}, of shape = {edge_data.shape} does not match with tids = ({start}, {end})"
                            edge_features[data_key] = edge_data
                            edge_feature_tids[data_key] = [(start, end)]
                        else:
                            edge_features[data_key] = None
                            edge_feature_tids[data_key] = [(0, 0)]

    # Done with building edge_features locally.
    if len(edge_features) <= 0:
        logging.debug(
            f"[Rank: {rank}] This dataset does not have any edge features"
        )
    else:
        assert len(edge_features) == len(edge_feature_tids)

        for k, v in edge_features.items():
            if v is None:
                continue
            logging.debug(
                f"[Rank: {rank}] edge feature name: {k}, feature data shape: {v.shape}"
            )
            tids = edge_feature_tids[k]
            count = tids[0][1] - tids[0][0]
            assert count == v.size()[0]

    """
    Code below is used to read edges from the input dataset with the help of the metadata json file
    for the input graph dataset.
    In the metadata json file, we expect the following key-value pairs to help read the edges of the
    input graph.

    "edge_type" : [ # a total of n edge types
        canonical_etype_0,
        canonical_etype_1,
        ...,
        canonical_etype_n-1
    ]

    The value for this key is a list of strings, where each string is associated with an edge type
    in the input graph. Note that these strings are in the canonical edge type format. This means
    these edge type strings follow the naming convention src_ntype:etype:dst_ntype, where src_ntype
    and dst_ntype are the node type names of the src and dst end points of this edge type, and
    etype is the relation name between the src and dst ntypes.

    The files in which edges are present and their storage format are described by the following
    key-value pair:

    "edges" : {
        "canonical_etype_0" : {
            "format" : { "name" : "csv", "delimiter" : " " },
            "data" : [
                filename_0,
                filename_1,
                filename_2,
                ....
                filename_<p-1>
            ]
        },
    }

    As shown above, the "edges" dictionary value has canonical edge types as keys, and for each
    canonical edge type we have "format" and "data", which describe the storage format of the edge
    files and the actual filenames respectively.
    Please note that each edge type's data is split into `p` files, where p is the no. of partitions
    to be made of the input graph.

    Each edge file contains two columns representing the source per-type node_ids and destination
    per-type node_ids of any given edge. Since these are node-ids as well, they are read in as
    int64's.
    """

    # read my edges for each edge type
    etype_names = schema_map[constants.STR_EDGE_TYPE]
    etype_name_idmap = {e: idx for idx, e in enumerate(etype_names)}

    edge_tids = {}
    edge_typecounts = {}
    edge_datadict = {}
    edge_data = schema_map[constants.STR_EDGES]

    # read the edge files and store this data in memory.
    for col in [
        constants.GLOBAL_SRC_ID,
        constants.GLOBAL_DST_ID,
        constants.GLOBAL_TYPE_EID,
        constants.ETYPE_ID,
    ]:
        edge_datadict[col] = []

    for etype_name, etype_id in etype_name_idmap.items():
        etype_info = edge_data[etype_name]
        edge_info = etype_info[constants.STR_DATA]

        # edge type strings are in canonical format, src_node_type:edge_type:dst_node_type
        tokens = etype_name.split(":")
        assert len(tokens) == 3

        src_ntype_name = tokens[0]
        dst_ntype_name = tokens[2]

        num_chunks = len(edge_info)
        read_list = generate_read_list(num_chunks, world_size)
        src_ids = []
        dst_ids = []

        """
        curr_partids = []
        for part_id in range(num_parts):
            if map_partid_rank(part_id, world_size) == rank:
                curr_partids.append(read_list[part_id])

        for idx in np.concatenate(curr_partids):
        """
        for idx in read_list[rank]:
            edge_file = edge_info[idx]
            if not os.path.isabs(edge_file):
                edge_file = os.path.join(input_dir, edge_file)
            logging.debug(
                f"[Rank: {rank}] Loading edges of etype[{etype_name}] from {edge_file}"
            )

            if (
                etype_info[constants.STR_FORMAT][constants.STR_NAME]
                == constants.STR_CSV
            ):
                read_options = pyarrow.csv.ReadOptions(
                    use_threads=True,
                    block_size=4096,
                    autogenerate_column_names=True,
                )
                parse_options = pyarrow.csv.ParseOptions(delimiter=" ")
                with pyarrow.csv.open_csv(
                    edge_file,
                    read_options=read_options,
                    parse_options=parse_options,
                ) as reader:
                    for next_chunk in reader:
                        if next_chunk is None:
                            break

                        next_table = pyarrow.Table.from_batches([next_chunk])
                        src_ids.append(next_table["f0"].to_numpy())
                        dst_ids.append(next_table["f1"].to_numpy())
            elif (
                etype_info[constants.STR_FORMAT][constants.STR_NAME]
                == constants.STR_PARQUET
            ):
                data_df = pq.read_table(edge_file)
                data_df = data_df.rename_columns(["f0", "f1"])
                src_ids.append(data_df["f0"].to_numpy())
                dst_ids.append(data_df["f1"].to_numpy())
            else:
                raise ValueError(
                    f"Unknown edge format {etype_info[constants.STR_FORMAT][constants.STR_NAME]} for edge type {etype_name}"
                )

        if len(src_ids) > 0:
            src_ids = np.concatenate(src_ids)
            dst_ids = np.concatenate(dst_ids)

            # src_ids/dst_ids read from the files are per-type node-ids; adding
            # the per-type global-id offsets converts them to global node-ids.
            edge_datadict[constants.GLOBAL_SRC_ID].append(
                src_ids + ntype_gnid_offset[src_ntype_name][0]
            )
            edge_datadict[constants.GLOBAL_DST_ID].append(
                dst_ids + ntype_gnid_offset[dst_ntype_name][0]
            )
            edge_datadict[constants.ETYPE_ID].append(
                etype_name_idmap[etype_name]
                * np.ones(shape=(src_ids.shape), dtype=np.int64)
            )
        else:
            src_ids = np.array([])

        # broadcast shape to compute the etype_id and global_eid's later.
        cur_tids = _broadcast_shape(
            src_ids, rank, world_size, num_parts, False, None
        )
        edge_typecounts[etype_name] = cur_tids[-1][1]
        edge_tids[etype_name] = cur_tids

        for local_part_id in range(num_parts):
            if map_partid_rank(local_part_id, world_size) == rank:
                if len(cur_tids) > local_part_id:
                    edge_datadict[constants.GLOBAL_TYPE_EID].append(
                        np.arange(
                            cur_tids[local_part_id][0],
                            cur_tids[local_part_id][1],
                            dtype=np.int64,
                        )
                    )
                    # edge_tids[etype_name] = [(cur_tids[local_part_id][0], cur_tids[local_part_id][1])]
                    assert len(edge_datadict[constants.GLOBAL_SRC_ID]) == len(
                        edge_datadict[constants.GLOBAL_TYPE_EID]
                    ), f"Error while reading edges from the disk, local_part_id = {local_part_id}, num_parts = {num_parts}, world_size = {world_size} cur_tids = {cur_tids}"

    # stitch together to create the final data on the local machine
    for col in [
        constants.GLOBAL_SRC_ID,
        constants.GLOBAL_DST_ID,
        constants.GLOBAL_TYPE_EID,
        constants.ETYPE_ID,
    ]:
        if len(edge_datadict[col]) > 0:
            edge_datadict[col] = np.concatenate(edge_datadict[col])

    if len(edge_datadict[constants.GLOBAL_SRC_ID]) > 0:
        assert (
            edge_datadict[constants.GLOBAL_SRC_ID].shape
            == edge_datadict[constants.GLOBAL_DST_ID].shape
        )
        assert (
            edge_datadict[constants.GLOBAL_DST_ID].shape
            == edge_datadict[constants.GLOBAL_TYPE_EID].shape
        )
        assert (
            edge_datadict[constants.GLOBAL_TYPE_EID].shape
            == edge_datadict[constants.ETYPE_ID].shape
        )
        logging.debug(
            f"[Rank: {rank}] Done reading edge_file: {len(edge_datadict)}, {edge_datadict[constants.GLOBAL_SRC_ID].shape}"
        )
    else:
        assert edge_datadict[constants.GLOBAL_SRC_ID] == []
        assert edge_datadict[constants.GLOBAL_DST_ID] == []
        assert edge_datadict[constants.GLOBAL_TYPE_EID] == []

        edge_datadict[constants.GLOBAL_SRC_ID] = np.array([], dtype=np.int64)
        edge_datadict[constants.GLOBAL_DST_ID] = np.array([], dtype=np.int64)
        edge_datadict[constants.GLOBAL_TYPE_EID] = np.array([], dtype=np.int64)
        edge_datadict[constants.ETYPE_ID] = np.array([], dtype=np.int64)

    logging.debug(f"[Rank: {rank}] edge_feature_tids: {edge_feature_tids}")

    return (
        node_features,
        node_feature_tids,
        edge_datadict,
        edge_typecounts,
        edge_tids,
        edge_features,
        edge_feature_tids,
    )
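
# A minimal usage sketch (hypothetical values and paths, for illustration
# only), assuming torch.distributed has already been initialized with a gloo
# backend and `schema_map` / `ntype_counts` were built from the metadata json:
#
#     (node_feats, node_feat_tids, edge_data, etype_counts, edge_tids,
#      edge_feats, edge_feat_tids) = get_dataset(
#         "/path/to/chunked-dataset", "mygraph", rank, world_size,
#         num_parts, schema_map, ntype_counts,
#     )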