explain_main.py 3.54 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# The major idea of the overall GNN model explanation

import argparse
import os
import dgl

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from dgl import load_graphs
from models import dummy_gnn_model
from NodeExplainerModule import NodeExplainerModule
from utils_graph import extract_subgraph, visualize_sub_graph


def main(args):
    # load an exisitng model or ask for training a model
    model_path = os.path.join('./', 'dummy_model_{}.pth'.format(args.dataset))
    if os.path.exists(model_path):
        model_stat_dict = th.load(model_path)
    else:
        raise FileExistsError('No Saved Model file. Please train a GNN model first...')

    # load graph, feat, and label
    g_list, label_dict = load_graphs('./'+args.dataset+'.bin')
    graph = g_list[0]
    labels = graph.ndata['label']
    feats = graph.ndata['feat']
    num_classes = max(labels).item() + 1
    feat_dim = feats.shape[1]
    hid_dim = label_dict['hid_dim'].item()
    
    # create a model and load from state_dict
    dummy_model = dummy_gnn_model(feat_dim, hid_dim, num_classes)
    dummy_model.load_state_dict(model_stat_dict)

    # Choose a node of the target class to be explained and extract its subgraph.
    # Here just pick the first one of the target class.
    target_list = [i for i, e in enumerate(labels) if e==args.target_class]
    n_idx = th.tensor([target_list[0]])

    # Extract the computation graph within k-hop of target node and use it for explainability
    sub_graph, ori_n_idxes, new_n_idx = extract_subgraph(graph, n_idx, hops=args.hop)
    
    #Sub-graph features.
    sub_feats = feats[ori_n_idxes,:]

    # create an explainer
    explainer = NodeExplainerModule(model=dummy_model,
                                    num_edges=sub_graph.number_of_edges(),
                                    node_feat_dim=feat_dim)

    # define optimizer
    optim = th.optim.Adam(explainer.parameters(), lr=args.lr, weight_decay=args.wd)

    # train the explainer for the given node
    dummy_model.eval()
    model_logits = dummy_model(sub_graph, sub_feats)
    model_predict = F.one_hot(th.argmax(model_logits, dim=-1), num_classes)

    for epoch in range(args.epochs):
        explainer.train()
        exp_logits = explainer(sub_graph, sub_feats)
        loss = explainer._loss(exp_logits[new_n_idx], model_predict[new_n_idx])

        optim.zero_grad()
        loss.backward()
        optim.step()

    # visualize the importance of edges
    edge_weights = explainer.edge_mask.sigmoid().detach()
    visualize_sub_graph(sub_graph, edge_weights.numpy(), ori_n_idxes, n_idx)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Demo of GNN explainer in DGL')
    parser.add_argument('--dataset', type=str, default='syn1',
                        help='The dataset to be explained.')
    parser.add_argument('--target_class', type=int, default='1',
                        help='The class to be explained. In the synthetic 1 dataset, Valid option is from 0 to 4'
                             'Will choose the first node in this class to explain')
    parser.add_argument('--hop', type=int, default='2',
                        help='The hop number of the computation sub-graph. For syn1 and syn2, k=2. For syn3, syn4, and syn5, k=4.')
    parser.add_argument('--epochs', type=int, default=200, help='The number of epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='The learning rate.')
    parser.add_argument('--wd', type=float, default=0.0, help='Weight decay.')
    args = parser.parse_args()
    print(args)

    main(args)