"docs/vscode:/vscode.git/clone" did not exist on "e9ef39d2e94ab30c1615456b1bf4dddcddc82f2b"
gen_dataset_stat.py 2.45 KB
Newer Older
1
2
3
4
from pytablewriter import RstGridTableWriter, MarkdownTableWriter
import numpy as np
import pandas as pd
from dgl import DGLGraph
from dgl.data.gnn_benchmark import AmazonCoBuy, CoraFull, Coauthor
from dgl.data.karate import KarateClub
from dgl.data.gindt import GINDataset
from dgl.data.bitcoinotc import BitcoinOTC
from dgl.data.gdelt import GDELT
from dgl.data.icews18 import ICEWS18
from dgl.data.qm7b import QM7b
# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset

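# Map a display name to the Python expression that constructs the dataset.
# Entries separated by '/' list the available splits; the statistics below are
# computed from the first split only (see the eval call in the loop further down).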
ds_list = {
    "BitcoinOTC": "BitcoinOTC()",
    "Cora": "CitationGraphDataset('cora')",
    "Citeseer": "CitationGraphDataset('citeseer')",
    "PubMed": "CitationGraphDataset('pubmed')",
    "QM7b": "QM7b()",
    "Reddit": "RedditDataset()",
    "ENZYMES": "TUDataset('ENZYMES')",
    "DD": "TUDataset('DD')",
    "COLLAB": "TUDataset('COLLAB')",
    "MUTAG": "TUDataset('MUTAG')",
    "PROTEINS": "TUDataset('PROTEINS')",
    "PPI": "PPIDataset('train')/PPIDataset('valid')/PPIDataset('test')",
    # "Cora Binary": "CitationGraphDataset('cora_binary')",
    "KarateClub": "KarateClub()",
    "Amazon computer": "AmazonCoBuy('computers')",
    "Amazon photo": "AmazonCoBuy('photo')",
    "Coauthor cs": "Coauthor('cs')",
    "Coauthor physics": "Coauthor('physics')",
    "GDELT": "GDELT('train')/GDELT('valid')/GDELT('test')",
    "ICEWS18": "ICEWS18('train')/ICEWS18('valid')/ICEWS18('test')",
    "CoraFull": "CoraFull()",
}

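# Emit an RST grid table for the Sphinx docs; switch to MarkdownTableWriter for Markdown output.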
writer = RstGridTableWriter()
# writer = MarkdownTableWriter()

# Some datasets yield DGLGraph objects directly; others yield (graph, label) tuples.
def extract_graph(g):
    return g if isinstance(g, DGLGraph) else g[0]

stat_list = []
for k, v in ds_list.items():
    print(k, ' ', v)
    # When several splits are listed (e.g. train/valid/test), only the first is instantiated.
    ds = eval(v.split("/")[0])
    num_nodes = []
    num_edges = []
    for i in range(len(ds)):
        g = extract_graph(ds[i])
        num_nodes.append(g.number_of_nodes())
        num_edges.append(g.number_of_edges())

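    # Inspect the first graph to report which node/edge feature fields the dataset carries.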
    gg = extract_graph(ds[0])
    dd = {
        "Datset Name": k,
        "Usage": v,
        "# of graphs": len(ds),
        "Avg. # of nodes": np.mean(num_nodes),
        "Avg. # of edges": np.mean(num_edges),
        "Node field": ', '.join(list(gg.ndata.keys())),
        "Edge field": ', '.join(list(gg.edata.keys())),
        # "Graph field": ', '.join(ds[0][0].gdata.keys()) if hasattr(ds[0][0], "gdata") else "",
        "Temporal": hasattr(ds, "is_temporal")
    }
    stat_list.append(dd)

print(dd.keys())
df = pd.DataFrame(stat_list)
df = df.reindex(columns=dd.keys())
writer.from_dataframe(df)

writer.write_table()
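
# A possible extension (not part of the original script): dump the table to a file for the docs,
# assuming pytablewriter's dump() accepts an output path. The path below is only illustrative.
# writer.dump("dataset_stats.rst")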