OpenDAS / dgl · Commit 70ad5083 (Unverified)

[Model] Add Node explanation for Homogenous PGExplainer Impl. (#5839)

Authored Jul 09, 2023 by Nick Baker; committed by GitHub on Jul 10, 2023.
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Parent: e2d35f62
Showing 3 changed files with 293 additions and 30 deletions (+293 −30):

- docs/source/_templates/classtemplate.rst (+1 −1)
- python/dgl/nn/pytorch/explain/pgexplainer.py (+263 −24)
- tests/python/pytorch/nn/test_nn.py (+29 −5)
docs/source/_templates/classtemplate.rst

```diff
@@ -7,4 +7,4 @@
 .. autoclass:: {{ name }}
     :show-inheritance:
-    :members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph, train_step
+    :members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph, train_step, train_step_node
```
python/dgl/nn/pytorch/explain/pgexplainer.py

```diff
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .... import ETYPE, to_homogeneous
+from .... import batch, ETYPE, khop_in_subgraph, NID, to_homogeneous
 
 __all__ = ["PGExplainer", "HeteroPGExplainer"]
```
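The three new imports carry the node-level machinery: `khop_in_subgraph` extracts the k-hop in-neighborhood around a seed node and reports the seed's new ID inside that subgraph, `NID` keys the parent-graph node IDs stored on the subgraph, and `batch` merges per-node subgraphs into one disjoint union. A minimal sketch of how they compose (the toy graph is illustrative):

```python
import dgl
import torch

# A toy directed cycle; edges are illustrative only.
g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))

# 2-hop in-neighborhood of node 2; inv gives node 2's ID inside sg.
sg, inv = dgl.khop_in_subgraph(g, 2, k=2)

# dgl.NID maps subgraph nodes back to parent IDs, which is how the diff
# below slices the parent feature matrix: feat[sg.ndata[NID].long()].
print(sg.ndata[dgl.NID], inv)

# dgl.batch fuses several subgraphs into one graph with shifted node IDs.
bg = dgl.batch([sg, sg])
print(bg.num_nodes())
```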
```diff
@@ -30,6 +30,11 @@ class PGExplainer(nn.Module):
         the intermediate node embeddings.
     num_features : int
         Node embedding size used by :attr:`model`.
+    num_hops : int, optional
+        The number of hops for GNN information aggregation, which must match the
+        number of message passing layers employed by the GNN to be explained.
+    explain_graph : bool, optional
+        Whether to initialize the model for graph-level or node-level predictions.
     coff_budget : float, optional
         Size regularization to constrain the explanation size. Default: 0.01.
     coff_connect : float, optional
```
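The `num_hops` constraint is the receptive-field argument: an L-layer message-passing GNN's prediction at a node depends only on that node's L-hop in-neighborhood, so extracting exactly L hops preserves everything the model can see. A sketch of a matching model/explainer pairing under that rule (the two-layer model is hypothetical, mirroring the docstring example further down):

```python
import dgl
import torch

class TwoLayerGCN(torch.nn.Module):
    # Two GraphConv layers => a node's output depends on its 2-hop
    # in-neighborhood, so the matching explainer needs num_hops=2.
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
        self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)

    def forward(self, g, h, embed=False, edge_weight=None):
        h = self.conv1(g, h, edge_weight=edge_weight)
        if embed:
            return h
        return self.conv2(g, h)

model = TwoLayerGCN(16, 4)
explainer = dgl.nn.PGExplainer(model, 4, num_hops=2, explain_graph=False)
```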
```diff
@@ -43,6 +48,8 @@ class PGExplainer(nn.Module):
         self,
         model,
         num_features,
+        num_hops=None,
+        explain_graph=True,
         coff_budget=0.01,
         coff_connect=5e-4,
         sample_bias=0.0,
```
```diff
@@ -50,7 +57,9 @@ class PGExplainer(nn.Module):
         super(PGExplainer, self).__init__()
         self.model = model
-        self.num_features = num_features * 2
+        self.graph_explanation = explain_graph
+        self.num_features = num_features * (2 if self.graph_explanation else 3)
+        self.num_hops = num_hops
 
         # training hyperparameters for PGExplainer
         self.coff_budget = coff_budget
```
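The changed multiplier tracks what the explanation MLP (`elayers`) consumes per edge: graph-level explanation scores the concatenation of source and destination embeddings (2×d), while node-level explanation appends the center node's embedding as a third block (3×d), matching the `torch.cat([col_emb, row_emb, self_emb], dim=-1)` in `explain_node` below. A quick shape check (dimensions are illustrative):

```python
import torch

d, num_edges = 16, 10
col_emb = torch.randn(num_edges, d)                # source endpoints
row_emb = torch.randn(num_edges, d)                # destination endpoints
self_emb = torch.randn(1, d).repeat(num_edges, 1)  # center node, tiled per edge

graph_mode = torch.cat([col_emb, row_emb], dim=-1)           # (E, 2 * d)
node_mode = torch.cat([col_emb, row_emb, self_emb], dim=-1)  # (E, 3 * d)
assert graph_mode.shape[-1] == 2 * d and node_mode.shape[-1] == 3 * d
```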
```diff
@@ -79,13 +88,14 @@ class PGExplainer(nn.Module):
             graph. The values are within range :math:`(0, 1)`. The higher,
             the more important. Default: None.
         """
-        num_nodes = graph.num_nodes()
-        num_edges = graph.num_edges()
-        init_bias = self.init_bias
-        std = nn.init.calculate_gain("relu") * math.sqrt(2.0 / (2 * num_nodes))
-
         if edge_mask is None:
+            num_nodes = graph.num_nodes()
+            num_edges = graph.num_edges()
+            init_bias = self.init_bias
+            std = nn.init.calculate_gain("relu") * math.sqrt(
+                2.0 / (2 * num_nodes)
+            )
             self.edge_mask = torch.randn(num_edges) * std + init_bias
         else:
             self.edge_mask = edge_mask
```
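Moving the statistics inside the `if` makes the initialization lazy; the heuristic itself is unchanged: a ReLU-gain, Xavier-style standard deviation that shrinks with graph size. A quick numeric check (a 100-node graph is assumed):

```python
import math
import torch.nn as nn

num_nodes = 100
std = nn.init.calculate_gain("relu") * math.sqrt(2.0 / (2 * num_nodes))
# calculate_gain("relu") == sqrt(2), so std == sqrt(2) * 0.1 ≈ 0.1414
print(round(std, 4))
```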
```diff
@@ -126,7 +136,7 @@ class PGExplainer(nn.Module):
             different types of label in the dataset and :math:`B` is
             the batch size.
         ori_pred: Tensor
-            Tensor of shape : :math:`(B, 1)`, representing the original prediction
+            Tensor of shape :math:`(B, 1)`, representing the original prediction
             for the graph, where :math:`B` is the batch size.
 
         Returns
```
```diff
@@ -216,17 +226,69 @@ class PGExplainer(nn.Module):
         Tensor
             A scalar tensor representing the loss.
         """
+        assert (
+            self.graph_explanation
+        ), '"explain_graph" must be True in initializing the module.'
+
         self.model = self.model.to(graph.device)
         self.elayers = self.elayers.to(graph.device)
 
-        pred = self.model(graph, feat, embed=False, **kwargs).argmax(-1).data
+        pred = self.model(graph, feat, embed=False, **kwargs)
+        pred = pred.argmax(-1).data
 
         prob, _ = self.explain_graph(
             graph, feat, tmp=tmp, training=True, **kwargs
         )
 
-        loss_tmp = self.loss(prob, pred)
-
-        return loss_tmp
+        loss = self.loss(prob, pred)
+
+        return loss
+
+    def train_step_node(self, nodes, graph, feat, tmp, **kwargs):
+        r"""Compute the loss of the explanation network
+
+        Parameters
+        ----------
+        nodes : int, iterable[int], tensor
+            The nodes from the graph used to train the explanation network,
+            which cannot have any duplicate value.
+        graph : DGLGraph
+            Input homogeneous graph.
+        feat : Tensor
+            The input feature of shape :math:`(N, D)`. :math:`N` is the
+            number of nodes, and :math:`D` is the feature size.
+        tmp : float
+            The temperature parameter fed to the sampling procedure.
+        kwargs : dict
+            Additional arguments passed to the GNN model.
+
+        Returns
+        -------
+        Tensor
+            A scalar tensor representing the loss.
+        """
+        assert (
+            not self.graph_explanation
+        ), '"explain_graph" must be False in initializing the module.'
+
+        self.model = self.model.to(graph.device)
+        self.elayers = self.elayers.to(graph.device)
+
+        if isinstance(nodes, torch.Tensor):
+            nodes = nodes.tolist()
+        if isinstance(nodes, int):
+            nodes = [nodes]
+
+        prob, _, batched_graph, inverse_indices = self.explain_node(
+            nodes, graph, feat, tmp=tmp, training=True, **kwargs
+        )
+        pred = self.model(
+            batched_graph, self.batched_feats, embed=False, **kwargs
+        )
+        pred = pred.argmax(-1).data
+
+        loss = self.loss(prob[inverse_indices], pred[inverse_indices])
+
+        return loss
+
     def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs):
         r"""Learn and return an edge mask that plays a crucial role to
```
```diff
@@ -324,19 +386,20 @@ class PGExplainer(nn.Module):
         >>> graph_feat = graph.ndata.pop("attr")
         >>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
         """
+        assert (
+            self.graph_explanation
+        ), '"explain_graph" must be True in initializing the module.'
+
         self.model = self.model.to(graph.device)
         self.elayers = self.elayers.to(graph.device)
 
         embed = self.model(graph, feat, embed=True, **kwargs)
         embed = embed.data
 
-        edge_idx = graph.edges()
-
-        col, row = edge_idx
+        col, row = graph.edges()
         col_emb = embed[col.long()]
         row_emb = embed[row.long()]
         emb = torch.cat([col_emb, row_emb], dim=-1)
         emb = self.elayers(emb)
         values = emb.reshape(-1)
```
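The per-edge scores in `values` are then squashed into (0, 1) gates by `concrete_sample`, a binary-concrete (Gumbel-sigmoid) relaxation in which the temperature controls how close the gates sit to hard 0/1 decisions. A minimal sketch of that relaxation, not DGL's exact implementation (which also supports a sampling bias):

```python
import torch

def binary_concrete(log_alpha, beta=1.0, training=True):
    # Noisy sigmoid gates during training; plain sigmoid at inference.
    # Lower beta => sharper, more nearly binary gates.
    if training:
        u = torch.rand_like(log_alpha)
        noise = torch.log(u) - torch.log(1.0 - u)  # logistic noise
        return torch.sigmoid((noise + log_alpha) / beta)
    return torch.sigmoid(log_alpha)
```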
```diff
@@ -352,10 +415,188 @@ class PGExplainer(nn.Module):
         logits = self.model(graph, feat, edge_weight=self.edge_mask, **kwargs)
         probs = F.softmax(logits, dim=-1)
 
-        if not training:
+        if training:
+            probs = probs.data
+        else:
             self.clear_masks()
 
-        return (probs, edge_mask) if training else (probs.data, edge_mask)
+        return (probs, edge_mask)
+
+    def explain_node(
+        self, nodes, graph, feat, tmp=1.0, training=False, **kwargs
+    ):
+        r"""Learn and return an edge mask that plays a crucial role to
+        explain the prediction made by the GNN for node :attr:`node_id`.
+        Also, return the prediction made with the edges chosen based on
+        the edge mask.
+
+        Parameters
+        ----------
+        nodes : int, iterable[int], tensor
+            The nodes from the graph, which cannot have any duplicate value.
+        graph : DGLGraph
+            A homogeneous graph.
+        feat : Tensor
+            The input feature of shape :math:`(N, D)`. :math:`N` is the
+            number of nodes, and :math:`D` is the feature size.
+        tmp : float
+            The temperature parameter fed to the sampling procedure.
+        training : bool
+            Training the explanation network.
+        kwargs : dict
+            Additional arguments passed to the GNN model.
+
+        Returns
+        -------
+        Tensor
+            Classification probabilities given the masked graph. It is a tensor of
+            shape :math:`(B, L)`, where :math:`L` is the different types of label
+            in the dataset, and :math:`B` is the batch size.
+        Tensor
+            Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
+            is the number of edges in the graph. A higher weight suggests a larger
+            contribution of the edge.
+        DGLGraph
+            The batched set of subgraphs induced on the k-hop in-neighborhood
+            of the input center nodes.
+        Tensor
+            The new IDs of the subgraph center nodes.
+
+        Examples
+        --------
+
+        >>> import dgl
+        >>> import numpy as np
+        >>> import torch
+
+        >>> # Define the model
+        >>> class Model(torch.nn.Module):
+        ...     def __init__(self, in_feats, out_feats):
+        ...         super().__init__()
+        ...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
+        ...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
+        ...
+        ...     def forward(self, g, h, embed=False, edge_weight=None):
+        ...         h = self.conv1(g, h, edge_weight=edge_weight)
+        ...         if embed:
+        ...             return h
+        ...         return self.conv2(g, h)
+
+        >>> # Load dataset
+        >>> data = dgl.data.CoraGraphDataset(verbose=False)
+        >>> g = data[0]
+        >>> features = g.ndata["feat"]
+        >>> labels = g.ndata["label"]
+
+        >>> # Train the model
+        >>> model = Model(features.shape[1], data.num_classes)
+        >>> criterion = torch.nn.CrossEntropyLoss()
+        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+        >>> for epoch in range(20):
+        ...     logits = model(g, features)
+        ...     loss = criterion(logits, labels)
+        ...     optimizer.zero_grad()
+        ...     loss.backward()
+        ...     optimizer.step()
+
+        >>> # Initialize the explainer
+        >>> explainer = dgl.nn.PGExplainer(
+        ...     model, data.num_classes, num_hops=2, explain_graph=False
+        ... )
+
+        >>> # Train the explainer
+        >>> # Define explainer temperature parameter
+        >>> init_tmp, final_tmp = 5.0, 1.0
+        >>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
+        >>> epochs = 10
+        >>> for epoch in range(epochs):
+        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
+        ...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)
+        ...     optimizer_exp.zero_grad()
+        ...     loss.backward()
+        ...     optimizer_exp.step()
+
+        >>> # Explain the prediction for graph 0
+        >>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        ...     0, g, features
+        ... )
+        """
+        assert (
+            not self.graph_explanation
+        ), '"explain_graph" must be False in initializing the module.'
+        assert (
+            self.num_hops is not None
+        ), '"num_hops" must be provided in initializing the module.'
+
+        if isinstance(nodes, torch.Tensor):
+            nodes = nodes.tolist()
+        if isinstance(nodes, int):
+            nodes = [nodes]
+
+        self.model = self.model.to(graph.device)
+        self.elayers = self.elayers.to(graph.device)
+
+        batched_graph = []
+        batched_feats = []
+        batched_embed = []
+        batched_inverse_indices = []
+        node_idx = 0
+        for node_id in nodes:
+            sg, inverse_indices = khop_in_subgraph(
+                graph, node_id, self.num_hops
+            )
+            sg_feat = feat[sg.ndata[NID].long()]
+
+            embed = self.model(sg, sg_feat, embed=True, **kwargs)
+            embed = embed.data
+
+            col, row = sg.edges()
+            col_emb = embed[col.long()]
+            row_emb = embed[row.long()]
+            self_emb = embed[inverse_indices[0]].repeat(sg.num_edges(), 1)
+            emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
+            batched_embed.append(emb)
+            batched_graph.append(sg)
+            batched_feats.append(sg_feat)
+
+            # node id's of subgraph mapped to batch:
+            # https://docs.dgl.ai/en/latest/generated/dgl.batch.html#dgl.batch
+            batched_inverse_indices.append(
+                inverse_indices[0].item() + node_idx
+            )
+            node_idx += sg.num_nodes()
+
+        batched_graph = batch(batched_graph)
+        batched_feats = torch.cat(batched_feats)
+        batched_embed = torch.cat(batched_embed)
+        batched_embed = self.elayers(batched_embed)
+        values = batched_embed.reshape(-1)
+
+        values = self.concrete_sample(values, beta=tmp, training=training)
+        self.sparse_mask_values = values
+
+        col, row = batched_graph.edges()
+        reverse_eids = batched_graph.edge_ids(row, col).long()
+        edge_mask = (values + values[reverse_eids]) / 2
+
+        self.set_masks(batched_graph, edge_mask)
+
+        # the model prediction with the updated edge mask
+        logits = self.model(
+            batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs
+        )
+        probs = F.softmax(logits, dim=-1)
+
+        if training:
+            self.batched_feats = batched_feats
+            probs = probs.data
+        else:
+            self.clear_masks()
+
+        return (
+            probs.data,
+            edge_mask,
+            batched_graph,
+            batched_inverse_indices,
+        )
+
 
 class HeteroPGExplainer(PGExplainer):
```
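The center-node bookkeeping above is the subtle part of `explain_node`: `dgl.batch` renumbers the i-th subgraph's nodes by the total node count of the subgraphs before it, so each center's in-subgraph index must be offset by the running `node_idx` to stay valid in the batched graph. A standalone sketch of that mapping (the toy graph is assumed):

```python
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))

centers, subgraphs, offsets, node_idx = [0, 2], [], [], 0
for c in centers:
    sg, inv = dgl.khop_in_subgraph(g, c, k=1)
    # inv[0] is c's ID inside sg; adding the running offset yields
    # its ID inside the batched graph built below.
    offsets.append(inv[0].item() + node_idx)
    node_idx += sg.num_nodes()
    subgraphs.append(sg)

bg = dgl.batch(subgraphs)
print(offsets, bg.num_nodes())  # offsets index the centers within bg
```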
```diff
@@ -560,11 +801,9 @@ class HeteroPGExplainer(PGExplainer):
         logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
         probs = F.softmax(logits, dim=-1)
 
-        if not training:
+        if training:
+            probs = probs.data
+        else:
             self.clear_masks()
 
-        return (
-            (probs, hetero_edge_mask)
-            if training
-            else (probs.data, hetero_edge_mask)
-        )
+        return (probs, hetero_edge_mask)
```
tests/python/pytorch/nn/test_nn.py
```diff
@@ -1826,8 +1826,9 @@ def test_pgexplainer(g, idtype, n_classes):
     g = transform(g)
 
     class Model(th.nn.Module):
-        def __init__(self, in_feats, out_feats):
+        def __init__(self, in_feats, out_feats, graph=False):
             super(Model, self).__init__()
+            self.graph = graph
             self.conv = nn.GraphConv(in_feats, out_feats)
             self.fc = th.nn.Linear(out_feats, out_feats)
             th.nn.init.xavier_uniform_(self.fc.weight)
```
```diff
@@ -1835,7 +1836,7 @@ def test_pgexplainer(g, idtype, n_classes):
         def forward(self, g, h, embed=False, edge_weight=None):
             h = self.conv(g, h, edge_weight=edge_weight)
 
-            if embed:
+            if not self.graph or embed:
                 return h
 
             with g.local_scope():
```
```diff
@@ -1843,14 +1844,36 @@ def test_pgexplainer(g, idtype, n_classes):
                 hg = dgl.mean_nodes(g, "h")
                 return self.fc(hg)
 
-    model = Model(feat.shape[1], n_classes)
+    # graph explainer
+    model = Model(feat.shape[1], n_classes, graph=True)
     model = model.to(ctx)
     explainer = nn.PGExplainer(model, n_classes)
 
     explainer.train_step(g, g.ndata["attr"], 5.0)
     probs, edge_weight = explainer.explain_graph(g, feat)
 
+    # node explainer
+    model = Model(feat.shape[1], n_classes, graph=False)
+    model = model.to(ctx)
+    explainer = nn.PGExplainer(
+        model, n_classes, num_hops=1, explain_graph=False
+    )
+
+    explainer.train_step_node(0, g, g.ndata["attr"], 5.0)
+    explainer.train_step_node([0, 1], g, g.ndata["attr"], 5.0)
+    explainer.train_step_node(th.tensor(0), g, g.ndata["attr"], 5.0)
+    explainer.train_step_node(th.tensor([0, 1]), g, g.ndata["attr"], 5.0)
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        0, g, feat
+    )
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        [0, 1], g, feat
+    )
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        th.tensor(0), g, feat
+    )
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        th.tensor([0, 1]), g, feat
+    )
+
 
 @pytest.mark.parametrize("g", get_cases(["hetero"]))
 @pytest.mark.parametrize("idtype", [F.int64])
```
```diff
@@ -1901,9 +1924,10 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
             return self.fc(hg)
 
     embed_dim = input_dim
+    # graph explainer
     model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
     model = model.to(ctx)
     explainer = nn.HeteroPGExplainer(model, embed_dim)
 
     explainer.train_step(g, feat, 5.0)
```