explain_main.py 3.56 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# The major idea of the overall GNN model explanation

import argparse
import os
import dgl

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from dgl import load_graphs
from models import dummy_gnn_model
from NodeExplainerModule import NodeExplainerModule
from utils_graph import extract_subgraph, visualize_sub_graph


def main(args):
    """Explain a saved GNN's prediction for one node and plot edge importance.

    Loads the saved state dict and the graph for ``args.dataset``, picks the
    first node whose label equals ``args.target_class``, extracts the
    ``args.hop``-hop subgraph around it, optimizes the edge mask and node
    feature mask of a ``NodeExplainerModule`` for ``args.epochs`` steps, and
    finally passes the sigmoid of the learned edge mask to
    ``visualize_sub_graph``.

    Args:
        args: Namespace-like object providing ``dataset``, ``target_class``,
            ``hop``, ``epochs``, ``lr`` and ``wd``.

    Raises:
        FileExistsError: If ``./dummy_model_<dataset>.pth`` is missing. The
            type is kept for backward compatibility with existing callers.
        ValueError: If no node in the graph is labelled ``args.target_class``.
    """
    # Load an existing model, or ask the user to train one first.
    model_path = os.path.join('./', 'dummy_model_{}.pth'.format(args.dataset))
    if not os.path.exists(model_path):
        raise FileExistsError('No Saved Model file. Please train a GNN model first...')
    # Nothing below moves tensors off the CPU (the edge mask is converted with
    # ``.numpy()`` directly), so map the checkpoint to CPU so that weights
    # saved on a GPU machine still load here.
    model_stat_dict = th.load(model_path, map_location='cpu')

    # Load graph, features, and labels.
    g_list, label_dict = load_graphs('./' + args.dataset + '.bin')
    graph = g_list[0]
    labels = graph.ndata['label']
    feats = graph.ndata['feat']
    num_classes = labels.max().item() + 1
    feat_dim = feats.shape[1]
    hid_dim = label_dict['hid_dim'].item()

    # Create a model and load its weights from the state dict.
    dummy_model = dummy_gnn_model(feat_dim, hid_dim, num_classes)
    dummy_model.load_state_dict(model_stat_dict)

    # Choose a node of the target class to be explained: here simply the
    # first node carrying that label.
    target_nodes = th.nonzero(labels == args.target_class, as_tuple=False).flatten()
    if target_nodes.numel() == 0:
        raise ValueError(
            'No node with label {} found in dataset {}.'.format(
                args.target_class, args.dataset))
    n_idx = th.tensor([target_nodes[0].item()])

    # Extract the computation graph within k hops of the target node and use
    # it for explainability.
    sub_graph, ori_n_idxes, new_n_idx = extract_subgraph(graph, n_idx, hops=args.hop)

    # Sub-graph features: rows of the full feature matrix selected by the
    # indices returned from extract_subgraph.
    sub_feats = feats[ori_n_idxes, :]

    # Create an explainer.
    explainer = NodeExplainerModule(model=dummy_model,
                                    num_edges=sub_graph.number_of_edges(),
                                    node_feat_dim=feat_dim)

    # Only the two masks are optimized; the model weights are not passed to
    # the optimizer.
    optim = th.optim.Adam([explainer.edge_mask, explainer.node_feat_mask],
                          lr=args.lr, weight_decay=args.wd)

    # Reference prediction of the unmasked model. Only the argmax is used, so
    # no autograd graph is needed for this forward pass.
    dummy_model.eval()
    with th.no_grad():
        model_logits = dummy_model(sub_graph, sub_feats)
    model_predict = F.one_hot(th.argmax(model_logits, dim=-1), num_classes)

    # Train the explainer for the given node.
    for epoch in range(args.epochs):
        explainer.train()
        exp_logits = explainer(sub_graph, sub_feats)
        loss = explainer._loss(exp_logits[new_n_idx], model_predict[new_n_idx])

        optim.zero_grad()
        loss.backward()
        optim.step()

    # Visualize the importance of edges.
    edge_weights = explainer.edge_mask.sigmoid().detach()
    visualize_sub_graph(sub_graph, edge_weights.numpy(), ori_n_idxes, n_idx)


if __name__ == '__main__':
    # Command-line entry point: parse the demo options, echo them, and run
    # the explanation for the selected dataset / class.
    parser = argparse.ArgumentParser(description='Demo of GNN explainer in DGL')
    parser.add_argument('--dataset', type=str, default='syn1',
                        help='The dataset to be explained.')
    # Integer options use integer defaults rather than strings that argparse
    # would have to convert.
    parser.add_argument('--target_class', type=int, default=1,
                        help='The class to be explained. In the synthetic 1 dataset, '
                             'valid options are from 0 to 4. '
                             'Will choose the first node in this class to explain.')
    parser.add_argument('--hop', type=int, default=2,
                        help='The hop number of the computation sub-graph. For syn1 and syn2, k=2. For syn3, syn4, and syn5, k=4.')
    parser.add_argument('--epochs', type=int, default=200, help='The number of epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='The learning rate.')
    parser.add_argument('--wd', type=float, default=0.0, help='Weight decay.')
    args = parser.parse_args()
    print(args)

    main(args)