dgl commit 61edb798 (unverified), authored by Mufei Li on Mar 23, 2022, committed via GitHub on Mar 23, 2022
[NN] Update GNNExplainer (#3848)
* Update * Update * Update
parent 8005978e

Showing 1 changed file with 36 additions and 12 deletions
python/dgl/nn/pytorch/explain/gnnexplainer.py  (+36, -12)
...
...
@@ -16,6 +16,16 @@ class GNNExplainer(nn.Module):
     It identifies compact subgraph structures and small subsets of node features that play a
     critical role in GNN-based node classification and graph classification.
+
+    To generate an explanation, it learns an edge mask :math:`M` and a feature mask :math:`F`
+    by optimizing the following objective function.
+
+    .. math::
+
+      l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)
+
+    where :math:`l` is the loss function, :math:`y` is the original model prediction,
+    :math:`\hat{y}` is the model prediction with the edge and feature mask applied, and
+    :math:`H` is the entropy function.

     Parameters
     ----------
     model : nn.Module
...
...
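For reference, the objective in the new docstring maps term by term onto the loss code added further down in this diff. The standalone sketch below is illustrative only, not part of the commit; the function name and argument names are hypothetical, and the masks are assumed to be unconstrained logits squashed by a sigmoid, as in the training loop later in the file.

    import torch

    def explainer_objective(pred_loss, edge_mask_logits, feat_mask_logits,
                            alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1,
                            eps=1e-15):
        # pred_loss is l(y, y_hat), the task loss with both masks applied.
        M = edge_mask_logits.sigmoid()   # edge mask M, entries in (0, 1)
        F = feat_mask_logits.sigmoid()   # feature mask F, entries in (0, 1)
        # alpha1 * ||M||_1: shrink the total edge mask (sparser subgraph)
        loss = pred_loss + alpha1 * torch.sum(M)
        # alpha2 * H(M): elementwise binary entropy, pushes entries toward 0 or 1
        ent_m = -M * torch.log(M + eps) - (1 - M) * torch.log(1 - M + eps)
        loss = loss + alpha2 * ent_m.mean()
        # beta1 * ||F||_1: realized as a mean over features, matching the diff
        loss = loss + beta1 * torch.mean(F)
        # beta2 * H(F): same entropy penalty for the feature mask
        ent_f = -F * torch.log(F + eps) - (1 - F) * torch.log(1 - F + eps)
        loss = loss + beta2 * ent_f.mean()
        return loss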
@@ -35,28 +45,42 @@ class GNNExplainer(nn.Module):
         The learning rate to use, defaults to 0.01.
     num_epochs : int, optional
         The number of epochs to train.
+    alpha1 : float, optional
+        A higher value will make the explanation edge masks more sparse by decreasing
+        the sum of the edge mask.
+    alpha2 : float, optional
+        A higher value will make the explanation edge masks more sparse by decreasing
+        the entropy of the edge mask.
+    beta1 : float, optional
+        A higher value will make the explanation node feature masks more sparse by
+        decreasing the mean of the node feature mask.
+    beta2 : float, optional
+        A higher value will make the explanation node feature masks more sparse by
+        decreasing the entropy of the node feature mask.
     log : bool, optional
         If True, it will log the computation process, defaults to True.
     """
-    coeffs = {'edge_size': 0.005,
-              'edge_ent': 1.0,
-              'node_feat_size': 1.0,
-              'node_feat_ent': 0.1}

+    def __init__(self, model, num_hops, lr=0.01, num_epochs=100, *,
+                 alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True):
         super(GNNExplainer, self).__init__()
         self.model = model
         self.num_hops = num_hops
         self.lr = lr
         self.num_epochs = num_epochs
+        self.alpha1 = alpha1
+        self.alpha2 = alpha2
+        self.beta1 = beta1
+        self.beta2 = beta2
         self.log = log

     def _init_masks(self, graph, feat):
...
...
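Since this commit replaces the `coeffs` class attribute with keyword-only constructor arguments, the four coefficients can now be tuned per explainer instance. A minimal usage sketch under stated assumptions: the `Model` below is a hypothetical one-layer `GraphConv` classifier, `GNNExplainer` expects the model's forward to accept `graph`, `feat`, and an optional `eweight` argument, and (as of DGL versions around this commit) `explain_node` returns the re-indexed center node, the extracted subgraph, and the learned feature and edge masks.

    import dgl
    import torch
    import torch.nn as nn
    from dgl.nn import GraphConv
    from dgl.nn.pytorch.explain import GNNExplainer

    class Model(nn.Module):
        # Hypothetical classifier; forward(graph, feat, eweight=None) is the
        # interface GNNExplainer drives while optimizing the masks.
        def __init__(self, in_feats, n_classes):
            super().__init__()
            self.conv = GraphConv(in_feats, n_classes)

        def forward(self, graph, feat, eweight=None):
            return self.conv(graph, feat, edge_weight=eweight)

    g = dgl.rand_graph(8, 24)            # random graph: 8 nodes, 24 edges
    feat = torch.randn(8, 5)
    model = Model(5, 2)

    # The coefficients are now per-instance, keyword-only arguments; e.g. a
    # larger alpha1 should yield a sparser explanation subgraph.
    explainer = GNNExplainer(model, num_hops=1, alpha1=0.05, alpha2=1.0,
                             beta1=1.0, beta2=0.1)
    center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)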
@@ -114,19 +138,19 @@ class GNNExplainer(nn.Module):
         edge_mask = edge_mask.sigmoid()
         # Edge mask sparsity regularization
-        loss = loss + self.coeffs['edge_size'] * torch.sum(edge_mask)
+        loss = loss + self.alpha1 * torch.sum(edge_mask)
         # Edge mask entropy regularization
         ent = -edge_mask * torch.log(edge_mask + eps) - \
             (1 - edge_mask) * torch.log(1 - edge_mask + eps)
-        loss = loss + self.coeffs['edge_ent'] * ent.mean()
+        loss = loss + self.alpha2 * ent.mean()

         feat_mask = feat_mask.sigmoid()
         # Feature mask sparsity regularization
-        loss = loss + self.coeffs['node_feat_size'] * torch.mean(feat_mask)
+        loss = loss + self.beta1 * torch.mean(feat_mask)
         # Feature mask entropy regularization
         ent = -feat_mask * torch.log(feat_mask + eps) - \
             (1 - feat_mask) * torch.log(1 - feat_mask + eps)
-        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
+        loss = loss + self.beta2 * ent.mean()

         return loss
...
...
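The quantity penalized in both branches above is the elementwise binary entropy, which peaks at mask values of 0.5 and vanishes as entries approach 0 or 1, so increasing alpha2 or beta2 drives the learned masks toward hard keep/drop decisions. A quick standalone check (illustrative only):

    import torch

    eps = 1e-15
    mask = torch.tensor([0.01, 0.5, 0.99])
    ent = -mask * torch.log(mask + eps) - (1 - mask) * torch.log(1 - mask + eps)
    print(ent)  # ~tensor([0.0560, 0.6931, 0.0560]): peaked at 0.5, near zero at the extremes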