Unverified Commit 3fe5eea7 authored by rudongyu's avatar rudongyu Committed by GitHub
Browse files

[NN] Label Propagation & Directional Graph Networks (#4017)



* add label propagation module

* fix prev bug in example

* add dgn

* fix linting and doc issues

* update label propagation & dgn

* update label propagation & dgn

* update example

* fix unit test

* fix agg heritage issue

* fix agg issue

* fix lint

* fix idx

* fix lp gpu issue

* Update

* Update
Co-authored-by: default avatarmufeili <mufeili1996@gmail.com>
parent 6de7d5fa
...@@ -39,6 +39,7 @@ Conv Layers ...@@ -39,6 +39,7 @@ Conv Layers
~dgl.nn.pytorch.conv.GroupRevRes ~dgl.nn.pytorch.conv.GroupRevRes
~dgl.nn.pytorch.conv.EGNNConv ~dgl.nn.pytorch.conv.EGNNConv
~dgl.nn.pytorch.conv.PNAConv ~dgl.nn.pytorch.conv.PNAConv
~dgl.nn.pytorch.conv.DGNConv
Dense Conv Layers Dense Conv Layers
---------------------------------------- ----------------------------------------
...@@ -111,3 +112,4 @@ Utility Modules ...@@ -111,3 +112,4 @@ Utility Modules
~dgl.nn.pytorch.utils.JumpingKnowledge ~dgl.nn.pytorch.utils.JumpingKnowledge
~dgl.nn.pytorch.sparse_emb.NodeEmbedding ~dgl.nn.pytorch.sparse_emb.NodeEmbedding
~dgl.nn.pytorch.explain.GNNExplainer ~dgl.nn.pytorch.explain.GNNExplainer
~dgl.nn.pytorch.utils.LabelPropagation
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import torch import torch
import dgl import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from model import LabelPropagation from dgl.nn import LabelPropagation
def main(): def main():
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class LabelPropagation(nn.Module):
r"""
Description
-----------
Introduced in `Learning from Labeled and Unlabeled Data with Label Propagation <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf>`_
.. math::
\mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y},
where unlabeled data is inferred by labeled data via propagation.
Parameters
----------
num_layers: int
The number of propagations.
alpha: float
The :math:`\alpha` coefficient.
"""
def __init__(self, num_layers, alpha):
super(LabelPropagation, self).__init__()
self.num_layers = num_layers
self.alpha = alpha
@torch.no_grad()
def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0., 1.)):
with g.local_scope():
if labels.dtype == torch.long:
labels = F.one_hot(labels.view(-1)).to(torch.float32)
y = labels
if mask is not None:
y = torch.zeros_like(labels)
y[mask] = labels[mask]
last = (1 - self.alpha) * y
degs = g.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5).to(labels.device).unsqueeze(1)
for _ in range(self.num_layers):
# Assume the graphs to be undirected
g.ndata['h'] = y * norm
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
y = last + self.alpha * g.ndata.pop('h') * norm
y = post_step(y)
last = (1 - self.alpha) * y
return y
...@@ -7,5 +7,5 @@ from .glob import * ...@@ -7,5 +7,5 @@ from .glob import *
from .softmax import * from .softmax import *
from .factory import * from .factory import *
from .hetero import * from .hetero import *
from .utils import Sequential, WeightBasis, JumpingKnowledge from .utils import Sequential, WeightBasis, JumpingKnowledge, LabelPropagation
from .sparse_emb import NodeEmbedding from .sparse_emb import NodeEmbedding
...@@ -30,10 +30,11 @@ from .hgtconv import HGTConv ...@@ -30,10 +30,11 @@ from .hgtconv import HGTConv
from .grouprevres import GroupRevRes from .grouprevres import GroupRevRes
from .egnnconv import EGNNConv from .egnnconv import EGNNConv
from .pnaconv import PNAConv from .pnaconv import PNAConv
from .dgnconv import DGNConv
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv', __all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GINEConv', 'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GINEConv',
'GatedGraphConv', 'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'GatedGraphConv', 'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv',
'DenseSAGEConv', 'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'DenseSAGEConv', 'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv',
'TWIRLSConv', 'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes', 'TWIRLSConv', 'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes',
'EGNNConv', 'PNAConv'] 'EGNNConv', 'PNAConv', 'DGNConv']
"""Torch Module for Directional Graph Networks Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch
import torch.nn as nn
from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower
def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
"""directional average aggregation"""
h_mod = torch.mul(h, (
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30)).unsqueeze(-1))
return torch.sum(h_mod, dim=1)
def aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx):
"""directional derivative aggregation"""
eig_w = ((
eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30
)
).unsqueeze(-1)
h_mod = torch.mul(h, eig_w)
return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)
for k in range(1, 4):
AGGREGATORS[f'dir{k}-av'] = partial(aggregate_dir_av, eig_idx=k-1)
AGGREGATORS[f'dir{k}-dx'] = partial(aggregate_dir_dx, eig_idx=k-1)
class DGNConvTower(PNAConvTower):
"""A single DGN tower with modified reduce function"""
def message(self, edges):
"""message function for DGN layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']}
def reduce_func(self, nodes):
"""reduce function for DGN layer"""
h_in = nodes.data['h']
eig_s = nodes.mailbox['eig_s']
eig_d = nodes.mailbox['eig_d']
msg = nodes.mailbox['msg']
degree = msg.size(1)
h = []
for agg in self.aggregators:
if agg.startswith('dir'):
if agg.endswith('av'):
h.append(AGGREGATORS[agg](msg, eig_s, eig_d))
else:
h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in))
else:
h.append(AGGREGATORS[agg](msg))
h = torch.cat(h, dim=1)
h = torch.cat([
SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
class DGNConv(PNAConv):
r"""Directional Graph Network Layer from `Directional Graph Networks
<https://arxiv.org/abs/2010.02863>`__
DGN introduces two special directional aggregators according to the vector field
:math:`F`, which is defined as the gradient of the low-frequency eigenvectors of graph
laplacian.
The directional average aggregator is defined as
:math:`h_i' = \sum_{j\in\mathcal{N}(i)}\frac{|F_{i,j}|\cdot h_j}{||F_{i,:}||_1+\epsilon}`
The directional derivative aggregator is defined as
:math:`h_i' = \sum_{j\in\mathcal{N}(i)}\frac{F_{i,j}\cdot h_j}{||F_{i,:}||_1+\epsilon}
-h_i\cdot\sum_{j\in\mathcal{N}(i)}\frac{F_{i,j}}{||F_{i,:}||_1+\epsilon}`
:math:`\epsilon` is the infinitesimal to keep the computation numerically stable.
Parameters
----------
in_size : int
Input feature size; i.e. the size of :math:`h_i^l`.
out_size : int
Output feature size; i.e. the size of :math:`h_i^{l+1}`.
aggregators : list of str
List of aggregation function names(each aggregator specifies a way to aggregate
messages from neighbours), selected from:
* ``mean``: the mean of neighbour messages
* ``max``: the maximum of neighbour messages
* ``min``: the minimum of neighbour messages
* ``std``: the standard deviation of neighbour messages
* ``var``: the variance of neighbour messages
* ``sum``: the sum of neighbour messages
* ``moment3``, ``moment4``, ``moment5``: the normalized moments aggregation
:math:`(E[(X-E[X])^n])^{1/n}`
* ``dir{k}-av``: directional average aggregation with directions defined by the k-th
smallest eigenvectors. k can be selected from 1, 2, 3.
* ``dir{k}-dx``: directional derivative aggregation with directions defined by the k-th
smallest eigenvectors. k can be selected from 1, 2, 3.
Note that using directional aggregation requires the LaplacianPE transform on the input
graph for eigenvector computation (the PE size must be >= k above).
scalers: list of str
List of scaler function names, selected from:
* ``identity``: no scaling
* ``amplification``: multiply the aggregated message by :math:`\log(d+1)/\delta`,
where :math:`d` is the in-degree of the node.
* ``attenuation``: multiply the aggregated message by :math:`\delta/\log(d+1)`
delta: float
The in-degree-related normalization factor computed over the training set, used by scalers
for normalization. :math:`E[\log(d+1)]`, where :math:`d` is the in-degree for each node
in the training set.
dropout: float, optional
The dropout ratio. Default: 0.0.
num_towers: int, optional
The number of towers used. Default: 1. Note that in_size and out_size must be divisible
by num_towers.
edge_feat_size: int, optional
The edge feature size. Default: 0.
residual : bool, optional
The bool flag that determines whether to add a residual connection for the
output. Default: True. If in_size and out_size of the DGN conv layer are not
the same, this flag will be set as False forcibly.
Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import DGNConv
>>> from dgl import LaplacianPE
>>>
>>> # DGN requires precomputed eigenvectors, with 'eig' as feature name.
>>> transform = LaplacianPE(k=3, feat_name='eig')
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = transform(g)
>>> eig = g.ndata['eig']
>>> feat = th.ones(6, 10)
>>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat, eig_vec=eig)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
super(DGNConv, self).__init__(
in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual
)
self.towers = nn.ModuleList([
DGNConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.use_eig_vec = False
for aggr in aggregators:
if aggr.startswith('dir'):
self.use_eig_vec = True
break
def forward(self, graph, node_feat, edge_feat=None, eig_vec=None):
r"""
Description
-----------
Compute DGN layer.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size.
edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size.
eig_vec : torch.Tensor, optional
K smallest non-trivial eigenvectors of Graph Laplacian of shape :math:`(N, K)`.
It is only required when :attr:`aggregators` contains directional aggregators.
Returns
-------
torch.Tensor
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
should be the same as out_size.
"""
with graph.local_scope():
if self.use_eig_vec:
graph.ndata['eig'] = eig_vec
return super().forward(graph, node_feat, edge_feat)
...@@ -95,9 +95,9 @@ class PNAConvTower(nn.Module): ...@@ -95,9 +95,9 @@ class PNAConvTower(nn.Module):
tensordot of multiple aggregation and scaling operations""" tensordot of multiple aggregation and scaling operations"""
msg = nodes.mailbox['msg'] msg = nodes.mailbox['msg']
degree = msg.size(1) degree = msg.size(1)
h = torch.cat([aggregator(msg) for aggregator in self.aggregators], dim=1) h = torch.cat([AGGREGATORS[agg](msg) for agg in self.aggregators], dim=1)
h = torch.cat([ h = torch.cat([
scaler(h, D=degree, delta=self.delta) if scaler is not scale_identity else h SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
for scaler in self.scalers for scaler in self.scalers
], dim=1) ], dim=1)
return {'h_neigh': h} return {'h_neigh': h}
...@@ -213,8 +213,6 @@ class PNAConv(nn.Module): ...@@ -213,8 +213,6 @@ class PNAConv(nn.Module):
def __init__(self, in_size, out_size, aggregators, scalers, delta, def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True): dropout=0., num_towers=1, edge_feat_size=0, residual=True):
super(PNAConv, self).__init__() super(PNAConv, self).__init__()
aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
scalers = [SCALERS[scale] for scale in scalers]
self.in_size = in_size self.in_size = in_size
self.out_size = out_size self.out_size = out_size
...@@ -254,7 +252,7 @@ class PNAConv(nn.Module): ...@@ -254,7 +252,7 @@ class PNAConv(nn.Module):
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size. nodes, and :math:`h_n` must be the same as in_size.
edge_feat : torch.Tensor, optional edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h)`. :math:`M` is the number of The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size. edges, and :math:`h_e` must be the same as edge_feat_size.
Returns Returns
......
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
import torch as th import torch as th
from torch import nn from torch import nn
import torch.nn.functional as F
from ... import DGLGraph from ... import DGLGraph
from ...base import dgl_warning from ...base import dgl_warning
from ... import function as fn
def matmul_maybe_select(A, B): def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector. """Perform Matrix multiplication C = A * B but A could be an integer id vector.
...@@ -395,3 +397,125 @@ class JumpingKnowledge(nn.Module): ...@@ -395,3 +397,125 @@ class JumpingKnowledge(nn.Module):
alpha = self.att(alpha).squeeze(-1) # (N, num_layers) alpha = self.att(alpha).squeeze(-1) # (N, num_layers)
alpha = th.softmax(alpha, dim=-1) alpha = th.softmax(alpha, dim=-1)
return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1) return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)
class LabelPropagation(nn.Module):
r"""Label Propagation from `Learning from Labeled and Unlabeled Data with Label
Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`__
.. math::
\mathbf{Y}^{(t+1)} = \alpha \cdot \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha)
\mathbf{Y}^{(0)},
where unlabeled data is initially set to zero and inferred from labeled data via
propagation. :math:`\alpha` is a weight parameter for balancing between updated labels
and initial labels. :math:`\tilde{A}` denotes the normalized adjacency matrix.
Parameters
----------
k: int
The number of propagation steps.
alpha : float
The :math:`\alpha` coefficient in range [0, 1].
norm_type : str, optional
The type of normalization applied to the adjacency matrix, must be one of the
following choices:
* ``row``: row-normalized adjacency as :math:`D^{-1}A`
* ``sym``: symmetrically normalized adjacency as :math:`D^{-1/2}AD^{-1/2}`
Default: 'sym'.
clamp : bool, optional
A bool flag to indicate whether to clamp the labels to [0, 1] after propagation.
Default: True.
normalize: bool, optional
A bool flag to indicate whether to apply row-normalization after propagation.
Default: False.
reset : bool, optional
A bool flag to indicate whether to reset the known labels after each
propagation step. Default: False.
Examples
--------
>>> import torch
>>> import dgl
>>> from dgl.nn import LabelPropagation
>>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)
>>> g = dgl.rand_graph(5, 10)
>>> labels = torch.tensor([0, 2, 1, 3, 0]).long()
>>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
>>> new_labels = label_propagation(g, labels, mask)
"""
def __init__(self, k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False):
super(LabelPropagation, self).__init__()
self.k = k
self.alpha = alpha
self.norm_type = norm_type
self.clamp = clamp
self.normalize = normalize
self.reset = reset
def forward(self, g, labels, mask=None):
r"""Compute the label propagation process.
Parameters
----------
g : DGLGraph
The input graph.
labels : torch.Tensor
The input node labels. There are three cases supported.
* A LongTensor of shape :math:`(N, 1)` or :math:`(N,)` for node class labels in
multiclass classification, where :math:`N` is the number of nodes.
* A LongTensor of shape :math:`(N, C)` for one-hot encoding of node class labels
in multiclass classification, where :math:`C` is the number of classes.
* A LongTensor of shape :math:`(N, L)` for node labels in multilabel binary
classification, where :math:`L` is the number of labels.
mask : torch.Tensor
The bool indicators of shape :math:`(N,)` with True denoting labeled nodes.
Default: None, indicating all nodes are labeled.
Returns
-------
torch.Tensor
The propagated node labels of shape :math:`(N, D)` with float type, where :math:`D`
is the number of classes or labels.
"""
with g.local_scope():
# multi-label / multi-class
if len(labels.size()) > 1 and labels.size(1) > 1:
labels = labels.to(th.float32)
# single-label multi-class
else:
labels = F.one_hot(labels.view(-1)).to(th.float32)
y = labels
if mask is not None:
y = th.zeros_like(labels)
y[mask] = labels[mask]
init = (1 - self.alpha) * y
in_degs = g.in_degrees().float().clamp(min=1)
out_degs = g.out_degrees().float().clamp(min=1)
if self.norm_type == 'sym':
norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)
norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)
elif self.norm_type == 'row':
norm_i = th.pow(in_degs, -1.).to(labels.device).unsqueeze(1)
else:
raise ValueError(f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}")
for _ in range(self.k):
g.ndata['h'] = y * norm_j if self.norm_type == 'sym' else y
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
y = init + self.alpha * g.ndata['h'] * norm_i
if self.clamp:
y = y.clamp_(0., 1.)
if self.normalize:
y = F.normalize(y, p=1)
if self.reset:
y[mask] = labels[mask]
return y
...@@ -1483,3 +1483,53 @@ def test_pna_conv(in_size, out_size, aggregators, scalers, delta, ...@@ -1483,3 +1483,53 @@ def test_pna_conv(in_size, out_size, aggregators, scalers, delta,
model = nn.PNAConv(in_size, out_size, aggregators, scalers, delta, dropout, model = nn.PNAConv(in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev) num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e) model(g, h, edge_feat=e)
@pytest.mark.parametrize('k', [3, 5])
@pytest.mark.parametrize('alpha', [0., 0.5, 1.])
@pytest.mark.parametrize('norm_type', ['sym', 'row'])
@pytest.mark.parametrize('clamp', [True, False])
@pytest.mark.parametrize('normalize', [True, False])
@pytest.mark.parametrize('reset', [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
dev = F.ctx()
num_nodes = 5
num_edges = 20
num_classes = 4
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)
model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)
model(g, labels, mask)
# multi-label case
model(g, ml_labels, mask)
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('aggregators',
[['mean', 'max', 'dir2-av'], ['min', 'std', 'dir1-dx'], ['moment3', 'moment4', 'dir3-av']])
@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
@pytest.mark.parametrize('delta', [2.5, 7.4])
@pytest.mark.parametrize('dropout', [0., 0.1])
@pytest.mark.parametrize('num_towers', [1, 4])
@pytest.mark.parametrize('edge_feat_size', [16, 0])
@pytest.mark.parametrize('residual', [True, False])
def test_dgn_conv(in_size, out_size, aggregators, scalers, delta,
dropout, num_towers, edge_feat_size, residual):
dev = F.ctx()
num_nodes = 5
num_edges = 20
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
h = th.randn(num_nodes, in_size).to(dev)
e = th.randn(num_edges, edge_feat_size).to(dev)
transform = dgl.LaplacianPE(k=3, feat_name='eig')
g = transform(g)
eig = g.ndata['eig']
model = nn.DGNConv(in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e, eig_vec=eig)
aggregators_non_eig = [aggr for aggr in aggregators if not aggr.startswith('dir')]
model = nn.DGNConv(in_size, out_size, aggregators_non_eig, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment