Unverified Commit 61edb798 authored by Mufei Li, committed by GitHub

[NN] Update GNNExplainer (#3848)

* Update

* Update

* Update
parent 8005978e
@@ -16,6 +16,16 @@ class GNNExplainer(nn.Module):
     It identifies compact subgraph structures and small subsets of node features that play a
     critical role in GNN-based node classification and graph classification.
+    To generate an explanation, it learns an edge mask :math:`M` and a feature mask :math:`F`
+    by optimizing the following objective function:
+
+    .. math::
+
+      l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F),
+
+    where :math:`l` is the loss function, :math:`y` is the original model prediction,
+    :math:`\hat{y}` is the model prediction with the edge and feature masks applied, and
+    :math:`H` is the entropy function.
     Parameters
     ----------
     model : nn.Module
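For readers skimming the diff, here is a minimal, self-contained sketch (not part of this commit) of the objective the new docstring describes, assuming PyTorch. The function name, argument names, and the `eps` smoothing constant below are illustrative placeholders, not DGL API:

```python
import torch

def masked_objective(pred_loss, edge_mask, feat_mask,
                     alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, eps=1e-15):
    """l(y, y_hat) plus the four mask regularizers from the docstring."""
    # alpha1 * ||M||_1: L1 sparsity on the (sigmoid-activated) edge mask
    loss = pred_loss + alpha1 * edge_mask.sum()
    # alpha2 * H(M): element-wise binary entropy of the edge mask
    ent_m = -edge_mask * torch.log(edge_mask + eps) \
            - (1 - edge_mask) * torch.log(1 - edge_mask + eps)
    loss = loss + alpha2 * ent_m.mean()
    # beta1 * ||F||_1 (taken as a mean, matching the module code below)
    loss = loss + beta1 * feat_mask.mean()
    # beta2 * H(F): element-wise binary entropy of the feature mask
    ent_f = -feat_mask * torch.log(feat_mask + eps) \
            - (1 - feat_mask) * torch.log(1 - feat_mask + eps)
    return loss + beta2 * ent_f.mean()
```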
@@ -35,28 +45,42 @@ class GNNExplainer(nn.Module):
         The learning rate to use, default to 0.01.
     num_epochs : int, optional
         The number of epochs to train.
+    alpha1 : float, optional
+        A higher value will make the explanation edge masks more sparse by decreasing
+        the sum of the edge mask.
+    alpha2 : float, optional
+        A higher value will make the explanation edge masks more sparse by decreasing
+        the entropy of the edge mask.
+    beta1 : float, optional
+        A higher value will make the explanation node feature masks more sparse by
+        decreasing the mean of the node feature mask.
+    beta2 : float, optional
+        A higher value will make the explanation node feature masks more sparse by
+        decreasing the entropy of the node feature mask.
     log : bool, optional
         If True, it will log the computation process, default to True.
     """
-    coeffs = {
-        'edge_size': 0.005,
-        'edge_ent': 1.0,
-        'node_feat_size': 1.0,
-        'node_feat_ent': 0.1
-    }
     def __init__(self,
                  model,
                  num_hops,
                  lr=0.01,
                  num_epochs=100,
+                 *,
+                 alpha1=0.005,
+                 alpha2=1.0,
+                 beta1=1.0,
+                 beta2=0.1,
                  log=True):
         super(GNNExplainer, self).__init__()
         self.model = model
         self.num_hops = num_hops
         self.lr = lr
         self.num_epochs = num_epochs
+        self.alpha1 = alpha1
+        self.alpha2 = alpha2
+        self.beta1 = beta1
+        self.beta2 = beta2
         self.log = log

     def _init_masks(self, graph, feat):
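With this change the regularization weights move from the class-level `coeffs` dict to keyword-only constructor arguments. A hedged usage sketch, not from this diff; the import path is assumed from DGL's nn namespace and `model` stands for any trained GNN `nn.Module`:

```python
from dgl.nn import GNNExplainer  # assumed import path; check your DGL version

# `model` is a previously trained GNN (an nn.Module); not defined here.
explainer = GNNExplainer(model,
                         num_hops=2,    # should match the GNN's receptive field
                         lr=0.01,
                         num_epochs=100,
                         alpha1=0.01,   # stronger L1 sparsity on the edge mask
                         beta2=0.1)     # keyword-only, like alpha1/alpha2/beta1
```

The masks themselves are then produced by the explainer's explain methods (such as `explain_node`), whose signatures are outside this diff.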
@@ -114,19 +138,19 @@ class GNNExplainer(nn.Module):
         edge_mask = edge_mask.sigmoid()
         # Edge mask sparsity regularization
-        loss = loss + self.coeffs['edge_size'] * torch.sum(edge_mask)
+        loss = loss + self.alpha1 * torch.sum(edge_mask)
         # Edge mask entropy regularization
         ent = - edge_mask * torch.log(edge_mask + eps) - \
               (1 - edge_mask) * torch.log(1 - edge_mask + eps)
-        loss = loss + self.coeffs['edge_ent'] * ent.mean()
+        loss = loss + self.alpha2 * ent.mean()
         feat_mask = feat_mask.sigmoid()
         # Feature mask sparsity regularization
-        loss = loss + self.coeffs['node_feat_size'] * torch.mean(feat_mask)
+        loss = loss + self.beta1 * torch.mean(feat_mask)
         # Feature mask entropy regularization
-        ent = -feat_mask * torch.log(feat_mask + eps) - \
+        ent = - feat_mask * torch.log(feat_mask + eps) - \
               (1 - feat_mask) * torch.log(1 - feat_mask + eps)
-        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
+        loss = loss + self.beta2 * ent.mean()
         return loss
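As a quick illustration of why the entropy terms encourage near-binary masks (not part of this diff; the `eps` value is assumed, since the method defines it elsewhere): binary entropy peaks at 0.5 and vanishes at the endpoints, so minimizing `ent.mean()` pushes mask entries toward 0 or 1, while the sum/mean terms push them toward 0.

```python
import torch

eps = 1e-15  # assumed smoothing constant
for m in (0.01, 0.5, 0.99):
    m = torch.tensor(m)
    # Same element-wise binary entropy used for both masks above
    ent = -m * torch.log(m + eps) - (1 - m) * torch.log(1 - m + eps)
    print(f"mask={m.item():.2f}  entropy={ent.item():.4f}")
# mask=0.01  entropy=0.0560
# mask=0.50  entropy=0.6931
# mask=0.99  entropy=0.0560
```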