# explain_main.py: explain a pretrained node classifier's prediction with
# DGL's GNNExplainer and dump the result for visualization in GNNLens2.
import argparse
import os

import dgl

import torch as th
from dgl import load_graphs
from dgl.data import (
    BACommunityDataset,
    BAShapeDataset,
    TreeCycleDataset,
    TreeGridDataset,
)
from dgl.nn import GNNExplainer
from gnnlens import Writer
from models import Model


def load_dataset(name):
    """Instantiate the synthetic DGL benchmark dataset called *name*.

    Every dataset is built with ``seed=0``; presumably this must match the
    seed used when the saved model was trained (TODO confirm).

    Args:
        name: One of ``"BAShape"``, ``"BACommunity"``, ``"TreeCycle"``,
            ``"TreeGrid"``.

    Returns:
        The corresponding DGL dataset object.

    Raises:
        ValueError: if *name* is not a supported dataset.
    """
    if name == "BAShape":
        return BAShapeDataset(seed=0)
    if name == "BACommunity":
        return BACommunityDataset(seed=0)
    if name == "TreeCycle":
        return TreeCycleDataset(seed=0)
    if name == "TreeGrid":
        return TreeGridDataset(seed=0)
    # Previously an unknown name left `dataset` unbound (UnboundLocalError).
    raise ValueError(f"Unsupported dataset: {name!r}")


def first_node_of_class(labels, target_class):
    """Return the index of the first node whose label equals *target_class*.

    Args:
        labels: Tensor of per-node class labels (assumed 1-D, one entry per
            node).
        target_class: Class label to search for.

    Returns:
        The node index as a Python ``int``.

    Raises:
        ValueError: if no node carries *target_class* (including the case
            where *labels* is empty).
    """
    matches = (labels == target_class).nonzero(as_tuple=True)[0]
    # The old linear scan silently fell through to the *last* node when no
    # match existed. Fail loudly instead of explaining the wrong node.
    if matches.numel() == 0:
        raise ValueError(f"No node has label {target_class}")
    return int(matches[0])


def main(args):
    """Explain one node's prediction with GNNExplainer and dump it for GNNLens2.

    Loads the dataset named by ``args.dataset`` and restores the pretrained
    model from ``./model_<dataset>.pth``. It then explains the model's
    prediction for the first node of class 1. The full graph and the
    explanation subgraph, weighted by the learned edge mask, are written
    into a new ``gnn_subgraph`` directory.

    Args:
        args: Parsed command-line namespace; only ``args.dataset`` is read.

    Raises:
        ValueError: if the dataset name is unknown or no node has class 1.
    """
    dataset = load_dataset(args.dataset)
    graph = dataset[0]
    labels = graph.ndata["label"]
    feats = graph.ndata["feat"]
    num_classes = dataset.num_classes

    # Load an existing model. map_location="cpu" allows a checkpoint saved
    # from GPU tensors to be restored on a machine without a GPU.
    model_path = os.path.join("./", f"model_{args.dataset}.pth")
    model_state_dict = th.load(model_path, map_location="cpu")
    model = Model(feats.shape[-1], num_classes)
    model.load_state_dict(model_state_dict)

    # Choose the first node of class 1 for explaining the prediction.
    target_class = 1
    node_idx = first_node_of_class(labels, target_class)

    explainer = GNNExplainer(model, num_hops=3)
    # explain_node returns (new_center, sub_graph, feat_mask, edge_mask).
    # Only the subgraph and its edge mask are visualized below.
    _, sub_graph, _, edge_mask = explainer.explain_node(node_idx, graph, feats)

    # gnnlens2: the path given to Writer is a new directory for dumping
    # data files.
    writer = Writer("gnn_subgraph")
    try:
        writer.add_graph(
            name=args.dataset,
            graph=graph,
            nlabels=labels,
            num_nlabel_types=num_classes,
        )
        # dgl.NID / dgl.EID hold the subgraph's node/edge IDs in the
        # original graph.
        writer.add_subgraph(
            graph_name=args.dataset,
            subgraph_name="GNNExplainer",
            node_id=node_idx,
            subgraph_nids=sub_graph.ndata[dgl.NID],
            subgraph_eids=sub_graph.edata[dgl.EID],
            subgraph_eweights=edge_mask,
        )
    finally:
        # Finish dumping; close the writer even if adding data failed.
        writer.close()

if __name__ == "__main__":
    # Command-line entry point: pick a synthetic dataset and explain the
    # pretrained model saved for it.
    parser = argparse.ArgumentParser(description="Demo of GNN explainer in DGL")
    parser.add_argument(
        "--dataset",
        type=str,
        default="BAShape",
        choices=["BAShape", "BACommunity", "TreeCycle", "TreeGrid"],
        help="Synthetic dataset whose saved model (model_<dataset>.pth) is explained.",
    )
    args = parser.parse_args()
    print(args)

    main(args)