Commit 9a0511c8 authored by Zihao Ye, committed by Minjie Wang

[NN] nn modules & examples update (#890)

* upd

* damn it

* fuck

* fuck pylint

* fudge

* remove some comments about MXNet

* upd

* upd

* damn it

* damn it

* fuck

* fuck

* upd

* upd

* pylint bastard

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd
parent 7f65199a
"""MXNet Module for Gated Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
from .... import function as fn
class GatedGraphConv(nn.Block):
r"""Gated Graph Convolution layer from paper `Gated Graph Sequence
Neural Networks <https://arxiv.org/pdf/1511.05493.pdf>`__.
.. math::
h_{i}^{0} & = [ x_i \| \mathbf{0} ]
a_{i}^{t} & = \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t}
h_{i}^{t+1} & = \mathrm{GRU}(a_{i}^{t}, h_{i}^{t})
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
n_steps : int
Number of recurrent steps.
n_etypes : int
Number of edge types.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
Can only be set to True in MXNet.
"""
def __init__(self,
in_feats,
out_feats,
n_steps,
n_etypes,
bias=True):
super(GatedGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._n_steps = n_steps
self._n_etypes = n_etypes
if not bias:
raise KeyError('MXNet does not support disabling bias in GRUCell.')
with self.name_scope():
self.linears = nn.Sequential()
for _ in range(n_etypes):
self.linears.add(
nn.Dense(out_feats,
weight_initializer=mx.init.Xavier(),
in_units=out_feats)
)
self.gru = gluon.rnn.GRUCell(out_feats, input_size=out_feats)
def forward(self, graph, feat, etypes):
"""Compute Gated Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
etypes : mxnet.NDArray
The edge type tensor of shape :math:`(E,)` where :math:`E` is
the number of edges of the graph.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]), ctx=feat.context)
feat = nd.concat(feat, zero_pad, dim=-1)
for _ in range(self._n_steps):
graph.ndata['h'] = feat
for i in range(self._n_etypes):
eids = (etypes.asnumpy() == i).nonzero()[0]
eids = nd.from_numpy(eids, zero_copy=True)
if len(eids) > 0:
graph.apply_edges(
lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
eids
)
graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a')
feat = self.gru(a, [feat])[0]
return feat
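# Illustrative usage sketch (not part of the commit): a minimal, self-contained example of
# driving the GatedGraphConv block defined above on a toy 4-node cycle graph. The graph,
# feature sizes, and edge types below are assumptions made for the example only.
import mxnet as mx
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])             # a 4-node cycle
feat = mx.nd.random.randn(4, 5)                      # in_feats = 5
etypes = mx.nd.array([0, 1, 0, 1], dtype='int64')    # two edge types
conv = GatedGraphConv(in_feats=5, out_feats=10, n_steps=2, n_etypes=2)
conv.initialize()
out = conv(g, feat, etypes)                          # -> shape (4, 10)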
"""MXNet Module for Graph Isomorphism Network layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import mxnet as mx
from mxnet.gluon import nn
from .... import function as fn
class GINConv(nn.Block):
r"""Graph Isomorphism Network layer from paper `How Powerful are Graph
Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__.
.. math::
h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
\mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
\right\}\right)\right)
Parameters
----------
apply_func : callable activation function/layer or None
If not None, apply this function to the updated node feature,
the :math:`f_\Theta` in the formula.
aggregator_type : str
Aggregator type to use (``sum``, ``max`` or ``mean``).
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter.
"""
def __init__(self,
apply_func,
aggregator_type,
init_eps=0,
learn_eps=False):
super(GINConv, self).__init__()
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'max':
self._reducer = fn.max
elif aggregator_type == 'mean':
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
with self.name_scope():
self.apply_func = apply_func
self.eps = self.params.get('eps',
shape=(1,),
grad_req='write' if learn_eps else 'null',
init=mx.init.Constant(init_eps))
def forward(self, graph, feat):
r"""Compute Graph Isomorphism Network layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D)` where :math:`D`
could be any positive integer, :math:`N` is the number
of nodes. If ``apply_func`` is not None, :math:`D` should
fit the input dimensionality requirement of ``apply_func``.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where
:math:`D_{out}` is the output dimensionality of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
"""
graph = graph.local_var()
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps.data(feat.context)) * feat + graph.ndata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
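# Illustrative usage sketch (not part of the commit): apply the GINConv block above with a
# small gluon Dense layer standing in for f_Theta. The toy graph and sizes are assumptions
# made for the example only.
import mxnet as mx
from mxnet.gluon import nn as gluon_nn
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
feat = mx.nd.random.randn(3, 8)
mlp = gluon_nn.Dense(16, in_units=8)     # plays the role of f_Theta
conv = GINConv(apply_func=mlp, aggregator_type='sum', init_eps=0.1, learn_eps=True)
conv.initialize()
out = conv(g, feat)                      # -> shape (3, 16)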
"""Torch Module for GMM Conv"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity
from .... import function as fn
class GMMConv(nn.Block):
r"""The Gaussian Mixture Model Convolution layer from `Geometric Deep
Learning on Graphs and Manifolds using Mixture Model CNNs
<http://openaccess.thecvf.com/content_cvpr_2017/papers/Monti_Geometric_Deep_Learning_CVPR_2017_paper.pdf>`__.
.. math::
h_i^{l+1} & = \mathrm{aggregate}\left(\left\{\frac{1}{K}
\sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right)
w_k(u) & = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right)
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
dim : int
Dimensionality of pseudo-coordinate.
n_kernels : int
Number of kernels :math:`K`.
aggregator_type : str
Aggregator type (``sum``, ``mean``, ``max``). Default: ``sum``.
residual : bool
If True, use residual connection inside this layer. Default: ``False``.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
"""
def __init__(self,
in_feats,
out_feats,
dim,
n_kernels,
aggregator_type='sum',
residual=False,
bias=True):
super(GMMConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._dim = dim
self._n_kernels = n_kernels
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'mean':
self._reducer = fn.mean
elif aggregator_type == 'max':
self._reducer = fn.max
else:
raise KeyError("Aggregator type {} not recognized.".format(aggregator_type))
with self.name_scope():
self.mu = self.params.get('mu',
shape=(n_kernels, dim),
init=mx.init.Normal(0.1))
self.inv_sigma = self.params.get('inv_sigma',
shape=(n_kernels, dim),
init=mx.init.Constant(1))
self.fc = nn.Dense(n_kernels * out_feats,
in_units=in_feats,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)))
if residual:
if in_feats != out_feats:
self.res_fc = nn.Dense(out_feats, in_units=in_feats, use_bias=False)
else:
self.res_fc = Identity()
else:
self.res_fc = None
if bias:
self.bias = self.params.get('bias',
shape=(out_feats,),
init=mx.init.Zero())
else:
self.bias = None
def forward(self, graph, feat, pseudo):
"""Compute Gaussian Mixture Model Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
pseudo : mxnet.NDArray
The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
:math:`E` is the number of edges of the graph and :math:`D_{u}`
is the dimensionality of pseudo coordinate.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
graph.ndata['h'] = self.fc(feat).reshape(-1, self._n_kernels, self._out_feats)
E = graph.number_of_edges()
# compute gaussian weight
gaussian = -0.5 * ((pseudo.reshape(E, 1, self._dim) -
self.mu.data(feat.context).reshape(1, self._n_kernels, self._dim)) ** 2)
gaussian = gaussian *\
(self.inv_sigma.data(feat.context).reshape(1, self._n_kernels, self._dim) ** 2)
gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1)
graph.edata['w'] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
rst = graph.ndata['h'].sum(1)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat)
# bias
if self.bias is not None:
rst = rst + self.bias.data(feat.context)
return rst
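# Illustrative usage sketch (not part of the commit): run the GMMConv block above with random
# 2-dimensional pseudo-coordinates on the edges of a toy graph. The graph and sizes are
# assumptions made for the example only.
import mxnet as mx
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
feat = mx.nd.random.randn(4, 5)
pseudo = mx.nd.random.randn(g.number_of_edges(), 2)   # dim = 2 pseudo-coordinates
conv = GMMConv(in_feats=5, out_feats=7, dim=2, n_kernels=3, aggregator_type='mean')
conv.initialize()
out = conv(g, feat, pseudo)                           # -> shape (4, 7)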
"""MXNet Module for NNConv layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity
from .... import function as fn
class NNConv(nn.Block):
r"""Graph Convolution layer introduced in `Neural Message Passing
for Quantum Chemistry <https://arxiv.org/pdf/1704.01212.pdf>`__.
.. math::
h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{
f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right)
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
edge_func : callable activation function/layer
Maps each edge feature to a vector of shape
``(in_feats * out_feats)`` as weight to compute
messages.
Also is the :math:`f_\Theta` in the formula.
aggregator_type : str
Aggregator type to use (``sum``, ``mean`` or ``max``).
residual : bool, optional
If True, use residual connection. Default: ``False``.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
"""
def __init__(self,
in_feats,
out_feats,
edge_func,
aggregator_type,
residual=False,
bias=True):
super(NNConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
if aggregator_type == 'sum':
self.reducer = fn.sum
elif aggregator_type == 'mean':
self.reducer = fn.mean
elif aggregator_type == 'max':
self.reducer = fn.max
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
self._aggre_type = aggregator_type
with self.name_scope():
self.edge_nn = edge_func
if residual:
if in_feats != out_feats:
self.res_fc = nn.Dense(out_feats, in_units=in_feats, use_bias=False,
weight_initializer=mx.init.Xavier())
else:
self.res_fc = Identity()
else:
self.res_fc = None
if bias:
self.bias = self.params.get('bias',
shape=(out_feats,),
init=mx.init.Zero())
else:
self.bias = None
def forward(self, graph, feat, efeat):
r"""Compute MPNN Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
input feature size.
efeat : mxnet.NDArray
The edge feature of shape :math:`(E, *)`, which should fit the input
shape requirement of ``edge_nn``.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
# (n, d_in, 1)
graph.ndata['h'] = feat.expand_dims(-1)
# (n, d_in, d_out)
graph.edata['w'] = self.edge_nn(efeat).reshape(-1, self._in_feats, self._out_feats)
# (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh'))
rst = graph.ndata.pop('neigh').sum(axis=1) # (n, d_out)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat)
# bias
if self.bias is not None:
rst = rst + self.bias.data(feat.context)
return rst
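# Illustrative usage sketch (not part of the commit): NNConv above with a small Dense edge
# network mapping a 4-dimensional edge feature to the required in_feats * out_feats message
# weights. The toy graph and sizes are assumptions made for the example only.
import mxnet as mx
from mxnet.gluon import nn as gluon_nn
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
feat = mx.nd.random.randn(4, 5)
efeat = mx.nd.random.randn(g.number_of_edges(), 4)
edge_nn = gluon_nn.Dense(5 * 6, in_units=4)   # maps each edge feature to a 5x6 weight block
conv = NNConv(in_feats=5, out_feats=6, edge_func=edge_nn, aggregator_type='sum')
conv.initialize()
out = conv(g, feat, efeat)                    # -> shape (4, 6)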
"""MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from .... import function as fn
class SAGEConv(nn.Block):
r"""GraphSAGE layer from paper `Inductive Representation Learning on
Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.
.. math::
h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate}
\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) + b \right)
h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{(l+1)})
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
feat_drop : float
Dropout rate on features, default: ``0``.
aggregator_type : str
Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
"""
def __init__(self,
in_feats,
out_feats,
aggregator_type='mean',
feat_drop=0.,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._aggre_type = aggregator_type
with self.name_scope():
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
if aggregator_type == 'pool':
self.fc_pool = nn.Dense(in_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
if aggregator_type == 'lstm':
raise NotImplementedError
if aggregator_type != 'gcn':
self.fc_self = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
self.fc_neigh = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
def forward(self, graph, feat):
r"""Compute GraphSAGE layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat
if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# divide in degrees
degs = graph.in_degrees().astype(feat.dtype)
degs = degs.as_in_context(feat.context)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.expand_dims(-1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = nd.relu(self.fc_pool(feat))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
elif self._aggre_type == 'lstm':
raise NotImplementedError
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self.norm is not None:
rst = self.norm(rst)
return rst
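# Illustrative usage sketch (not part of the commit): the SAGEConv block above with the
# 'pool' aggregator on a toy ring graph. The graph and sizes are assumptions made for the
# example only.
import mxnet as mx
import dgl

g = dgl.DGLGraph()
g.add_nodes(6)
g.add_edges([0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0])
feat = mx.nd.random.randn(6, 10)
conv = SAGEConv(in_feats=10, out_feats=16, aggregator_type='pool')
conv.initialize()
out = conv(g, feat)    # -> shape (6, 16)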
"""MXNet Module for Simplifying Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from .... import function as fn
class SGConv(nn.Block):
r"""Simplifying Graph Convolution layer from paper `Simplifying Graph
Convolutional Networks <https://arxiv.org/pdf/1902.07153.pdf>`__.
.. math::
H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l}
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
k : int
Number of hops :math:`K`. Default: ``1``.
cached : bool
If True, the module would cache
.. math::
(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta
at the first forward call. This parameter should only be set to
``True`` in the transductive learning setting.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
"""
def __init__(self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None):
super(SGConv, self).__init__()
self._cached = cached
self._cached_h = None
self._k = k
with self.name_scope():
self.norm = norm
self.fc = nn.Dense(out_feats, in_units=in_feats, use_bias=bias,
weight_initializer=mx.init.Xavier())
def forward(self, graph, feat):
r"""Compute Simplifying Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
Notes
-----
If ``cached`` is set to ``True``, ``feat`` and ``graph`` should not change during
training, or you will get wrong results.
"""
graph = graph.local_var()
if self._cached_h is not None:
feat = self._cached_h
else:
# compute normalization
degs = nd.clip(graph.in_degrees().astype(feat.dtype), 1, float('inf'))
norm = nd.power(degs, -0.5).expand_dims(1)
norm = norm.as_in_context(feat.context)
# compute (D^{-1/2} A D^{-1/2})^k X
for _ in range(self._k):
feat = feat * norm
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
feat = feat * norm
if self.norm is not None:
feat = self.norm(feat)
# cache feature
if self._cached:
self._cached_h = feat
return self.fc(feat)
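# Illustrative usage sketch (not part of the commit): a two-hop SGConv as defined above, with
# caching enabled for a transductive setting. The toy graph and sizes are assumptions made
# for the example only.
import mxnet as mx
import dgl

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])
feat = mx.nd.random.randn(5, 8)
conv = SGConv(in_feats=8, out_feats=3, k=2, cached=True)
conv.initialize()
out = conv(g, feat)    # -> shape (5, 3); the propagated features are cached after this call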
@@ -30,14 +30,14 @@ def matmul_maybe_select(A, B):
     Parameters
     ----------
-    A : torch.Tensor
+    A : mxnet.NDArray
         lhs tensor
-    B : torch.Tensor
+    B : mxnet.NDArray
         rhs tensor
     Returns
     -------
-    C : torch.Tensor
+    C : mxnet.NDArray
         result tensor
     """
     if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:
@@ -67,16 +67,16 @@ def bmm_maybe_select(A, B, index):
     Parameters
     ----------
-    A : torch.Tensor
+    A : mxnet.NDArray
         lhs tensor
-    B : torch.Tensor
+    B : mxnet.NDArray
         rhs tensor
-    index : torch.Tensor
+    index : mxnet.NDArray
         index tensor
     Returns
     -------
-    C : torch.Tensor
+    C : mxnet.NDArray
         return tensor
     """
     if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:
@@ -84,3 +84,24 @@ def bmm_maybe_select(A, B, index):
     else:
         BB = nd.take(B, index, axis=0)
     return nd.batch_dot(A.expand_dims(1), BB).squeeze()
def normalize(x, p=2, axis=1, eps=1e-12):
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
For a tensor :attr:`x` of sizes :math:`(n_0, ..., n_{axis}, ..., n_k)`, each
:math:`n_{axis}`-element vector :math:`v` along dimension :attr:`axis` is transformed as
.. math::
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
With the default arguments it uses the Euclidean norm over vectors along dimension
:math:`1` for normalization.
Args:
x: input ndarray of any shape
p (float): the exponent value in the norm formulation. Default: 2
axis (int): the axis to reduce. Default: 1
eps (float): small value to avoid division by zero. Default: 1e-12
"""
denom = nd.clip(nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float('inf'))
return x / denom
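# Illustrative usage sketch (not part of the commit): L2-normalize each row of a random
# matrix with the normalize helper above; every row of y then has (approximately) unit norm.
from mxnet import nd   # already imported in this module

x = nd.random.randn(4, 3)
y = normalize(x, p=2, axis=1)
print(nd.norm(y, axis=1))   # ~[1. 1. 1. 1.]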
@@ -58,8 +58,8 @@ class AGNNConv(nn.Module):
         graph.ndata['h'] = feat
         graph.ndata['norm_h'] = F.normalize(feat, p=2, dim=-1)
         # compute cosine distance
-        graph.apply_edges(fn.u_mul_v('norm_h', 'norm_h', 'cos'))
-        cos = graph.edata.pop('cos').sum(-1)
+        graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
+        cos = graph.edata.pop('cos')
         e = self.beta * cos
         graph.edata['p'] = edge_softmax(graph, e)
         graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
@@ -4,7 +4,6 @@ import torch as th
 from torch import nn
 from .... import function as fn
-from ..utils import Identity
 class APPNPConv(nn.Module):
@@ -35,7 +34,7 @@ class APPNPConv(nn.Module):
         super(APPNPConv, self).__init__()
         self._k = k
         self._alpha = alpha
-        self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity()
+        self.edge_drop = nn.Dropout(edge_drop)
     def forward(self, graph, feat):
         r"""Compute APPNP layer.
@@ -56,10 +55,11 @@ class APPNPConv(nn.Module):
         """
         graph = graph.local_var()
         norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
-        norm = norm.unsqueeze(-1).to(feat.device)
+        shp = norm.shape + (1,) * (feat.dim() - 1)
+        norm = th.reshape(norm, shp).to(feat.device)
         feat_0 = feat
         for _ in range(self._k):
-            # normalization by src
+            # normalization by src node
             feat = feat * norm
             graph.ndata['h'] = feat
             graph.edata['w'] = self.edge_drop(
@@ -67,7 +67,7 @@ class APPNPConv(nn.Module):
             graph.update_all(fn.u_mul_e('h', 'w', 'm'),
                              fn.sum('m', 'h'))
             feat = graph.ndata.pop('h')
-            # normalization by dst
+            # normalization by dst node
             feat = feat * norm
             feat = (1 - self._alpha) * feat + self._alpha * feat_0
         return feat
@@ -93,7 +93,7 @@ class ChebConv(nn.Module):
             lambda_max = laplacian_lambda_max(graph)
         if isinstance(lambda_max, list):
             lambda_max = th.Tensor(lambda_max).to(feat.device)
-        if lambda_max.dim() < 1:
+        if lambda_max.dim() == 1:
             lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
         # broadcast from (B, 1) to (N, 1)
         lambda_max = broadcast_nodes(graph, lambda_max)
@@ -73,7 +73,7 @@ class DenseSAGEConv(nn.Module):
         """
         adj = adj.float().to(feat.device)
         feat = self.feat_drop(feat)
-        in_degrees = adj.sum(dim=1).unsqueeze(-1)
+        in_degrees = adj.sum(dim=1, keepdim=True)
         h_neigh = (adj @ feat + feat) / (in_degrees + 1)
         rst = self.fc(h_neigh)
         # activation
@@ -12,7 +12,6 @@ class EdgeConv(nn.Module):
     <https://arxiv.org/pdf/1801.07829>`__". Can be described as follows:
     .. math::
        x_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}(
        \Theta \cdot (x_j^{(l)} - x_i^{(l)}) + \Phi \cdot x_i^{(l)})
@@ -27,7 +26,10 @@ class EdgeConv(nn.Module):
     batch_norm : bool
         Whether to include batch normalization on messages.
     """
-    def __init__(self, in_feat, out_feat, batch_norm=False):
+    def __init__(self,
+                 in_feat,
+                 out_feat,
+                 batch_norm=False):
         super(EdgeConv, self).__init__()
         self.batch_norm = batch_norm
"""Torch Module for Gated Graph Convolution layer""" """Torch Module for Gated Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop
import torch as th import torch as th
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
...@@ -41,7 +41,10 @@ class GatedGraphConv(nn.Module): ...@@ -41,7 +41,10 @@ class GatedGraphConv(nn.Module):
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
self._n_steps = n_steps self._n_steps = n_steps
self.edge_embed = nn.Embedding(n_etypes, out_feats * out_feats) self._n_etypes = n_etypes
self.linears = nn.ModuleList(
[nn.Linear(out_feats, out_feats) for _ in range(n_etypes)]
)
self.gru = nn.GRUCell(out_feats, out_feats, bias=bias) self.gru = nn.GRUCell(out_feats, out_feats, bias=bias)
self.reset_parameters() self.reset_parameters()
...@@ -49,7 +52,9 @@ class GatedGraphConv(nn.Module): ...@@ -49,7 +52,9 @@ class GatedGraphConv(nn.Module):
"""Reinitialize learnable parameters.""" """Reinitialize learnable parameters."""
gain = init.calculate_gain('relu') gain = init.calculate_gain('relu')
self.gru.reset_parameters() self.gru.reset_parameters()
init.xavier_normal_(self.edge_embed.weight, gain=gain) for linear in self.linears:
init.xavier_normal_(linear.weight, gain=gain)
init.zeros_(linear.bias)
def forward(self, graph, feat, etypes): def forward(self, graph, feat, etypes):
"""Compute Gated Graph Convolution layer. """Compute Gated Graph Convolution layer.
...@@ -75,13 +80,17 @@ class GatedGraphConv(nn.Module): ...@@ -75,13 +80,17 @@ class GatedGraphConv(nn.Module):
graph = graph.local_var() graph = graph.local_var()
zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1])) zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1]))
feat = th.cat([feat, zero_pad], -1) feat = th.cat([feat, zero_pad], -1)
# NOTE(zihao): there is still room to optimize, we may do kernel fusion
# for such operations in the future.
graph.edata['w'] = self.edge_embed(etypes).view(-1, self._out_feats, self._out_feats)
for _ in range(self._n_steps): for _ in range(self._n_steps):
graph.ndata['h'] = feat.unsqueeze(-1) # (N, D, 1) graph.ndata['h'] = feat
graph.update_all(fn.u_mul_e('h', 'w', 'm'), for i in range(self._n_etypes):
fn.sum('m', 'a')) eids = (etypes == i).nonzero().view(-1)
a = graph.ndata.pop('a').sum(dim=1) # (N, D) if len(eids) > 0:
graph.apply_edges(
lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
eids
)
graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
feat = self.gru(a, feat) feat = self.gru(a, feat)
return feat return feat
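# Illustrative usage sketch (not part of the commit): driving the updated PyTorch
# GatedGraphConv, which now keeps one nn.Linear per edge type instead of a single edge-type
# embedding of dense weight matrices. The import path dgl.nn.pytorch.conv and the toy graph
# below are assumptions made for the example only.
import torch as th
import dgl
from dgl.nn.pytorch.conv import GatedGraphConv

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
feat = th.randn(4, 5)
etypes = th.tensor([0, 1, 0, 1])
conv = GatedGraphConv(in_feats=5, out_feats=10, n_steps=2, n_etypes=2)
out = conv(g, feat, etypes)    # -> shape (4, 10)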
@@ -32,7 +32,7 @@ class GMMConv(nn.Module):
     aggregator_type : str
         Aggregator type (``sum``, ``mean``, ``max``).
     residual : bool
-        If True, use residual connection inside this layer.
+        If True, use residual connection inside this layer. Default: ``False``.
     bias : bool
         If True, adds a learnable bias to the output. Default: ``True``.
     """
@@ -41,8 +41,8 @@
                  out_feats,
                  dim,
                  n_kernels,
-                 aggregator_type,
-                 residual=True,
+                 aggregator_type='sum',
+                 residual=False,
                  bias=True):
         super(GMMConv, self).__init__()
         self._in_feats = in_feats
@@ -82,7 +82,7 @@
         if isinstance(self.res_fc, nn.Linear):
             init.xavier_normal_(self.res_fc.weight, gain=gain)
         init.normal_(self.mu.data, 0, 0.1)
-        init.normal_(self.inv_sigma.data, 1, 0.1)
+        init.constant_(self.inv_sigma.data, 1)
         if self.bias is not None:
             init.zeros_(self.bias.data)
@@ -47,6 +47,13 @@ class SGConv(nn.Module):
         self._cached_h = None
         self._k = k
         self.norm = norm
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """Reinitialize learnable parameters."""
+        nn.init.xavier_uniform_(self.fc.weight)
+        if self.fc.bias is not None:
+            nn.init.zeros_(self.fc.bias)
     def forward(self, graph, feat):
         r"""Compute Simplifying Graph Convolution layer.
@@ -77,9 +84,8 @@
             # compute normalization
             degs = graph.in_degrees().float().clamp(min=1)
             norm = th.pow(degs, -0.5)
-            norm[th.isinf(norm)] = 0
             norm = norm.to(feat.device).unsqueeze(1)
-            # compute (D^-1 A D) X
+            # compute (D^-1 A^k D)^k X
             for _ in range(self._k):
                 feat = feat * norm
                 graph.ndata['h'] = feat
@@ -112,6 +112,210 @@ def test_tagconv():
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2
def test_gat_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
gat = nn.GATConv(10, 20, 5) # n_heads = 5
gat.initialize(ctx=ctx)
print(gat)
# test#1: basic
h0 = F.randn((20, 10))
h1 = gat(g, h0)
assert h1.shape == (20, 5, 20)
def test_sage_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
graphsage = nn.SAGEConv(10, 20)
graphsage.initialize(ctx=ctx)
print(graphsage)
# test#1: basic
h0 = F.randn((20, 10))
h1 = graphsage(g, h0)
assert h1.shape == (20, 20)
def test_gg_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
gg_conv = nn.GatedGraphConv(10, 20, 3, 4) # n_step = 3, n_etypes = 4
gg_conv.initialize(ctx=ctx)
print(gg_conv)
# test#1: basic
h0 = F.randn((20, 10))
etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
h1 = gg_conv(g, h0, etypes)
assert h1.shape == (20, 20)
def test_cheb_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
cheb = nn.ChebConv(10, 20, 3) # k = 3
cheb.initialize(ctx=ctx)
print(cheb)
# test#1: basic
h0 = F.randn((20, 10))
h1 = cheb(g, h0)
assert h1.shape == (20, 20)
def test_agnn_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
agnn_conv = nn.AGNNConv(0.1, True)
agnn_conv.initialize(ctx=ctx)
print(agnn_conv)
# test#1: basic
h0 = F.randn((20, 10))
h1 = agnn_conv(g, h0)
assert h1.shape == (20, 10)
def test_appnp_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
appnp_conv = nn.APPNPConv(3, 0.1, 0)
appnp_conv.initialize(ctx=ctx)
print(appnp_conv)
# test#1: basic
h0 = F.randn((20, 10))
h1 = appnp_conv(g, h0)
assert h1.shape == (20, 10)
def test_dense_cheb_conv():
for k in range(1, 4):
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).tostype('default')
cheb = nn.ChebConv(5, 2, k)
dense_cheb = nn.DenseChebConv(5, 2, k)
cheb.initialize(ctx=ctx)
dense_cheb.initialize(ctx=ctx)
for i in range(len(cheb.fc)):
dense_cheb.fc[i].weight.set_data(
cheb.fc[i].weight.data())
if cheb.bias is not None:
dense_cheb.bias.set_data(
cheb.bias.data())
feat = F.randn((100, 5))
out_cheb = cheb(g, feat, [2.0])
out_dense_cheb = dense_cheb(adj, feat, 2.0)
assert F.allclose(out_cheb, out_dense_cheb)
def test_dense_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).tostype('default')
conv = nn.GraphConv(5, 2, norm=False, bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True)
conv.initialize(ctx=ctx)
dense_conv.initialize(ctx=ctx)
dense_conv.weight.set_data(
conv.weight.data())
dense_conv.bias.set_data(
conv.bias.data())
feat = F.randn((100, 5))
out_conv = conv(g, feat)
out_dense_conv = dense_conv(adj, feat)
assert F.allclose(out_conv, out_dense_conv)
def test_dense_sage_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).tostype('default')
sage = nn.SAGEConv(5, 2, 'gcn')
dense_sage = nn.DenseSAGEConv(5, 2)
sage.initialize(ctx=ctx)
dense_sage.initialize(ctx=ctx)
dense_sage.fc.weight.set_data(
sage.fc_neigh.weight.data())
dense_sage.fc.bias.set_data(
sage.fc_neigh.bias.data())
feat = F.randn((100, 5))
out_sage = sage(g, feat)
out_dense_sage = dense_sage(adj, feat)
assert F.allclose(out_sage, out_dense_sage)
def test_edge_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
edge_conv = nn.EdgeConv(5, 2)
edge_conv.initialize(ctx=ctx)
print(edge_conv)
# test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), 2)
def test_gin_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
gin_conv = nn.GINConv(lambda x: x, 'mean', 0.1)
gin_conv.initialize(ctx=ctx)
print(gin_conv)
# test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gin_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), 5)
def test_gmm_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
gmm_conv.initialize(ctx=ctx)
print(gmm_conv)
# test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
pseudo = F.randn((g.number_of_edges(), 5))
h1 = gmm_conv(g, h0, pseudo)
assert h1.shape == (g.number_of_nodes(), 2)
def test_nn_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
nn_conv.initialize(ctx=ctx)
print(nn_conv)
# test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
etypes = nd.random.randint(0, 3, g.number_of_edges()).as_in_context(ctx)
h1 = nn_conv(g, h0, etypes)
assert h1.shape == (g.number_of_nodes(), 2)
def test_sg_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()
sgc = nn.SGConv(5, 2, 2)
sgc.initialize(ctx=ctx)
print(sgc)
# test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = sgc(g, h0)
assert h1.shape == (g.number_of_nodes(), 2)
def test_set2set():
    g = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()
@@ -306,6 +510,20 @@ def test_rgcn():
if __name__ == '__main__':
    test_graph_conv()
test_gat_conv()
test_sage_conv()
test_gg_conv()
test_cheb_conv()
test_agnn_conv()
test_appnp_conv()
test_dense_cheb_conv()
test_dense_graph_conv()
test_dense_sage_conv()
test_edge_conv()
test_gin_conv()
test_gmm_conv()
test_nn_conv()
test_sg_conv()
    test_edge_softmax()
    test_partial_edge_softmax()
    test_set2set()
@@ -403,7 +403,6 @@ def test_gat_conv():
     if F.gpu_ctx():
         gat = gat.to(ctx)
-        feat = feat.to(ctx)
     h = gat(g, feat)
     assert h.shape[-1] == 2 and h.shape[-2] == 4
@@ -417,7 +416,6 @@ def test_sage_conv():
     if F.gpu_ctx():
         sage = sage.to(ctx)
-        feat = feat.to(ctx)
     h = sage(g, feat)
     assert h.shape[-1] == 10
@@ -431,7 +429,6 @@ def test_sgc_conv():
     if F.gpu_ctx():
         sgc = sgc.to(ctx)
-        feat = feat.to(ctx)
     h = sgc(g, feat)
     assert h.shape[-1] == 10
@@ -455,7 +452,6 @@ def test_appnp_conv():
     if F.gpu_ctx():
         appnp = appnp.to(ctx)
-        feat = feat.to(ctx)
     h = appnp(g, feat)
     assert h.shape[-1] == 5
@@ -472,7 +468,6 @@ def test_gin_conv():
     if F.gpu_ctx():
         gin = gin.to(ctx)
-        feat = feat.to(ctx)
     h = gin(g, feat)
     assert h.shape[-1] == 12
@@ -485,7 +480,6 @@ def test_agnn_conv():
     if F.gpu_ctx():
         agnn = agnn.to(ctx)
-        feat = feat.to(ctx)
     h = agnn(g, feat)
     assert h.shape[-1] == 5
@@ -499,7 +493,6 @@ def test_gated_graph_conv():
     if F.gpu_ctx():
         ggconv = ggconv.to(ctx)
-        feat = feat.to(ctx)
         etypes = etypes.to(ctx)
     h = ggconv(g, feat, etypes)
@@ -516,8 +509,6 @@ def test_nn_conv():
     if F.gpu_ctx():
         nnconv = nnconv.to(ctx)
-        feat = feat.to(ctx)
-        efeat = efeat.to(ctx)
     h = nnconv(g, feat, efeat)
     # currently we only do shape check
@@ -532,8 +523,6 @@ def test_gmm_conv():
     if F.gpu_ctx():
         gmmconv = gmmconv.to(ctx)
-        feat = feat.to(ctx)
-        pseudo = pseudo.to(ctx)
     h = gmmconv(g, feat, pseudo)
     # currently we only do shape check
@@ -551,7 +540,6 @@ def test_dense_graph_conv():
     if F.gpu_ctx():
         conv = conv.to(ctx)
         dense_conv = dense_conv.to(ctx)
-        feat = feat.to(ctx)
     out_conv = conv(g, feat)
     out_dense_conv = dense_conv(adj, feat)
@@ -561,7 +549,7 @@ def test_dense_sage_conv():
     ctx = F.ctx()
     g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
     adj = g.adjacency_matrix(ctx=ctx).to_dense()
-    sage = nn.SAGEConv(5, 2, 'gcn',)
+    sage = nn.SAGEConv(5, 2, 'gcn')
     dense_sage = nn.DenseSAGEConv(5, 2)
     dense_sage.fc.weight.data = sage.fc_neigh.weight.data
     dense_sage.fc.bias.data = sage.fc_neigh.bias.data
@@ -569,7 +557,6 @@ def test_dense_sage_conv():
     if F.gpu_ctx():
         sage = sage.to(ctx)
         dense_sage = dense_sage.to(ctx)
-        feat = feat.to(ctx)
     out_sage = sage(g, feat)
     out_dense_sage = dense_sage(adj, feat)
@@ -590,7 +577,6 @@ def test_dense_cheb_conv():
     if F.gpu_ctx():
         cheb = cheb.to(ctx)
         dense_cheb = dense_cheb.to(ctx)
-        feat = feat.to(ctx)
     out_cheb = cheb(g, feat, [2.0])
     out_dense_cheb = dense_cheb(adj, feat, 2.0)
-Subproject commit 7ce90a342b0bda9b7f88e707a326496324d60efd
+Subproject commit 0f3ddbc7240efa05bfffd5bca808ec262ce3630e