"""Write per-partition node/edge feature files for a partitioned OGB-MAG graph.

The script rebuilds the OGB-MAG heterograph with an added reverse relation
for every edge type, attaches features, labels and split masks to the
'paper' nodes, and then, for each partition listed in 'outputs/mag.json',
gathers the rows belonging to that partition and saves them to the
feature files named in the metadata.
"""
import dgl
import json
import torch as th
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset


def _index_mask(size, index):
    """Return a bool tensor of length *size* that is True at *index*."""
    mask = th.zeros((size,), dtype=th.bool)
    mask[index] = True
    return mask


# Load OGB-MAG.
dataset = DglNodePropPredDataset(name='ogbn-mag')
hg_orig, labels = dataset[0]

# Keep every original relation and add a 'rev-<relation>' twin whose
# endpoints are swapped.
subgs = {}
for canonical in hg_orig.canonical_etypes:
    src_type, relation, dst_type = canonical
    u, v = hg_orig.all_edges(etype=canonical)
    subgs[canonical] = (u, v)
    subgs[(dst_type, 'rev-' + relation, src_type)] = (v, u)
hg = dgl.heterograph(subgs)
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat']

# Train/valid/test splits are read for the 'paper' node type only.
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]['paper']
val_idx = split_idx["valid"]['paper']
test_idx = split_idx["test"]['paper']
paper_labels = labels['paper'].squeeze()

num_papers = hg.number_of_nodes('paper')
train_mask = _index_mask(num_papers, train_idx)
val_mask = _index_mask(num_papers, val_idx)
test_mask = _index_mask(num_papers, test_idx)

# Assignment order is kept (masks first, labels last) because it fixes the
# order of the keys written to the node feature files below.
hg.nodes['paper'].data['train_mask'] = train_mask
hg.nodes['paper'].data['val_mask'] = val_mask
hg.nodes['paper'].data['test_mask'] = test_mask
hg.nodes['paper'].data['labels'] = paper_labels

# Partition metadata: this script reads 'num_parts' and, per partition,
# the 'node_feats' / 'edge_feats' paths (relative to 'outputs/').
with open('outputs/mag.json') as json_file:
    metadata = json.load(json_file)

for part_id in range(metadata['num_parts']):
    part_key = 'part-{}'.format(part_id)
    loaded_graphs = dgl.load_graphs('outputs/part{}/graph.dgl'.format(part_id))
    subg = loaded_graphs[0][0]

    # NOTE(review): 'inner_node' presumably flags nodes owned by this
    # partition and 'orig_id' presumably holds their IDs in `hg` -- confirm
    # against the partitioning step that produced graph.dgl.
    is_inner = subg.ndata['inner_node'].bool()
    node_type_ids = subg.ndata[dgl.NTYPE]

    node_data = {}
    for ntype in hg.ntypes:
        of_this_type = node_type_ids == hg.get_ntype_id(ntype)
        local_node_idx = th.logical_and(is_inner, of_this_type)
        local_nodes = subg.ndata['orig_id'][local_node_idx].numpy()
        for name in hg.nodes[ntype].data:
            feature = hg.nodes[ntype].data[name]
            node_data[ntype + '/' + name] = feature[local_nodes]
    print('node features:', node_data.keys())
    node_feat_path = 'outputs/' + metadata[part_key]['node_feats']
    dgl.data.utils.save_tensors(node_feat_path, node_data)

    # NOTE(review): unlike nodes, edges are selected by type only (there is
    # no 'inner_edge' filter here) -- confirm this is intended.
    edge_data = {}
    for etype in hg.etypes:
        of_this_etype = subg.edata[dgl.ETYPE] == hg.get_etype_id(etype)
        local_edges = subg.edata['orig_id'][of_this_etype]
        for name in hg.edges[etype].data:
            feature = hg.edges[etype].data[name]
            edge_data[etype + '/' + name] = feature[local_edges]
    print('edge features:', edge_data.keys())
    edge_feat_path = 'outputs/' + metadata[part_key]['edge_feats']
    dgl.data.utils.save_tensors(edge_feat_path, edge_data)