Unverified Commit af61e2fb authored by Zihao Ye and committed by GitHub

[Feature] Support nn modules for bipartite graphs. (#1392)



* init gat

* fix

* gin

* 7 nn modules

* rename & lint

* upd

* upd

* fix lint

* upd test

* upd

* lint

* shape check

* upd

* lint

* address comments

* update tensorflow
Co-authored-by: Quan Gan <coin2028@hotmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 67cb7a43
...@@ -3992,6 +3992,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -3992,6 +3992,10 @@ class DGLGraph(DGLBaseGraph):
self._node_frame = old_nframe self._node_frame = old_nframe
self._edge_frame = old_eframe self._edge_frame = old_eframe
def is_homograph(self):
"""Return if the graph is homogeneous."""
return True
############################################################ ############################################################
# Batch/Unbatch APIs # Batch/Unbatch APIs
############################################################ ############################################################
......
...@@ -4106,6 +4106,10 @@ class DGLHeteroGraph(object): ...@@ -4106,6 +4106,10 @@ class DGLHeteroGraph(object):
self._node_frames = old_nframes self._node_frames = old_nframes
self._edge_frames = old_eframes self._edge_frames = old_eframes
def is_homograph(self):
"""Return if the graph is homogeneous."""
return len(self.ntypes) == 1 and len(self.etypes) == 1
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
......
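A minimal sketch of how the new helper behaves (illustrative only; `dgl.heterograph`/`dgl.graph` construction and the type names below are assumptions, not part of this patch):

```python
# Hedged sketch: is_homograph() is True only for graphs with exactly one
# node type and one edge type.
import dgl

hetero = dgl.heterograph({
    ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ('user', 'plays', 'game'): [(0, 0), (1, 1)],
})
print(hetero.is_homograph())        # False: two edge types

homo = dgl.graph([(0, 1), (1, 2)])  # one node type, one edge type
print(homo.is_homograph())          # True
```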
...@@ -6,6 +6,8 @@ from mxnet.gluon import nn ...@@ -6,6 +6,8 @@ from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from ..softmax import edge_softmax from ..softmax import edge_softmax
from ..utils import normalize from ..utils import normalize
from ....utils import expand_as_pair
class AGNNConv(nn.Block): class AGNNConv(nn.Block):
r"""Attention-based Graph Neural Network layer from paper `Attention-based r"""Attention-based Graph Neural Network layer from paper `Attention-based
...@@ -47,6 +49,9 @@ class AGNNConv(nn.Block): ...@@ -47,6 +49,9 @@ class AGNNConv(nn.Block):
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature of shape :math:`(N, *)` :math:`N` is the The input feature of shape :math:`(N, *)` :math:`N` is the
number of nodes, and :math:`*` could be of any shape. number of nodes, and :math:`*` could be of any shape.
If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
:math:`(N_{in}, *)` and :math:`(N_{out}, *)`; the :math:`*` in the latter
tensor must equal that in the former.
Returns Returns
------- -------
...@@ -55,12 +60,16 @@ class AGNNConv(nn.Block): ...@@ -55,12 +60,16 @@ class AGNNConv(nn.Block):
should be the same as input shape. should be the same as input shape.
""" """
graph = graph.local_var() graph = graph.local_var()
graph.ndata['h'] = feat
graph.ndata['norm_h'] = normalize(feat, p=2, axis=-1) feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = feat_src
graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
if isinstance(feat, tuple):
graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1)
# compute cosine distance # compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos')) graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
cos = graph.edata.pop('cos') cos = graph.edata.pop('cos')
e = self.beta.data(feat.context) * cos e = self.beta.data(feat_src.context) * cos
graph.edata['p'] = edge_softmax(graph, e) graph.edata['p'] = edge_softmax(graph, e)
graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h')) graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
return graph.ndata.pop('h') return graph.dstdata.pop('h')
...@@ -18,8 +18,11 @@ class DenseGraphConv(nn.Block): ...@@ -18,8 +18,11 @@ class DenseGraphConv(nn.Block):
Input feature size. Input feature size.
out_feats : int out_feats : int
Output feature size. Output feature size.
norm : bool norm : str, optional
If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``. How to apply the normalizer. If it is `'right'`, divide the aggregated messages
by each node's in-degrees, which is equivalent to averaging the received messages.
If it is `'none'`, no normalization is applied. Default is `'both'`,
where the :math:`c_{ij}` in the paper is applied.
bias : bool bias : bool
If True, adds a learnable bias to the output. Default: ``True``. If True, adds a learnable bias to the output. Default: ``True``.
activation : callable activation function/layer or None, optional activation : callable activation function/layer or None, optional
...@@ -33,7 +36,7 @@ class DenseGraphConv(nn.Block): ...@@ -33,7 +36,7 @@ class DenseGraphConv(nn.Block):
def __init__(self, def __init__(self,
in_feats, in_feats,
out_feats, out_feats,
norm=True, norm='both',
bias=True, bias=True,
activation=None): activation=None):
super(DenseGraphConv, self).__init__() super(DenseGraphConv, self).__init__()
...@@ -56,12 +59,14 @@ class DenseGraphConv(nn.Block): ...@@ -56,12 +59,14 @@ class DenseGraphConv(nn.Block):
Parameters Parameters
---------- ----------
adj : mxnet.NDArray adj : mxnet.NDArray
The adjacency matrix of the graph to apply Graph Convolution on, The adjacency matrix of the graph to apply Graph Convolution on, when
should be of shape :math:`(N, N)`, where a row represents the destination applied to a unidirectional bipartite graph, ``adj`` should be of shape
and a column represents the source. should be of shape :math:`(N_{out}, N_{in})`; when applied to a homo
feat : mxnet.NDArray graph, ``adj`` should be of shape :math:`(N, N)`. In both cases,
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` a row represents a destination node while a column represents a source
is size of input feature, :math:`N` is the number of nodes. node.
feat : mxnet.NDArray
The input feature.
Returns Returns
------- -------
...@@ -70,24 +75,33 @@ class DenseGraphConv(nn.Block): ...@@ -70,24 +75,33 @@ class DenseGraphConv(nn.Block):
is size of output feature. is size of output feature.
""" """
adj = adj.astype(feat.dtype).as_in_context(feat.context) adj = adj.astype(feat.dtype).as_in_context(feat.context)
if self._norm: src_degrees = nd.clip(adj.sum(axis=0), a_min=1, a_max=float('inf'))
in_degrees = adj.sum(axis=1) dst_degrees = nd.clip(adj.sum(axis=1), a_min=1, a_max=float('inf'))
norm = nd.power(in_degrees, -0.5) feat_src = feat
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context) if self._norm == 'both':
feat = feat * norm norm_src = nd.power(src_degrees, -0.5)
shp_src = norm_src.shape + (1,) * (feat.ndim - 1)
norm_src = norm_src.reshape(shp_src).as_in_context(feat.context)
feat_src = feat_src * norm_src
if self._in_feats > self._out_feats: if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation. # mult W first to reduce the feature size for aggregation.
feat = nd.dot(feat, self.weight.data(feat.context)) feat_src = nd.dot(feat_src, self.weight.data(feat_src.context))
rst = nd.dot(adj, feat) rst = nd.dot(adj, feat_src)
else: else:
# aggregate first then mult W # aggregate first then mult W
rst = nd.dot(adj, feat) rst = nd.dot(adj, feat_src)
rst = nd.dot(rst, self.weight.data(feat.context)) rst = nd.dot(rst, self.weight.data(feat_src.context))
if self._norm: if self._norm != 'none':
rst = rst * norm if self._norm == 'both':
norm_dst = nd.power(dst_degrees, -0.5)
else: # right
norm_dst = 1.0 / dst_degrees
shp_dst = norm_dst.shape + (1,) * (feat.ndim - 1)
norm_dst = norm_dst.reshape(shp_dst).as_in_context(feat.context)
rst = rst * norm_dst
if self.bias is not None: if self.bias is not None:
rst = rst + self.bias.data(feat.context) rst = rst + self.bias.data(feat.context)
......
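A minimal usage sketch of the reworked normalization (illustrative; the bipartite-shaped adjacency, the `norm='right'` choice, and all sizes are assumptions rather than part of this patch):

```python
# Hedged sketch: DenseGraphConv on a dense adjacency where rows are
# destination nodes and columns are source nodes.
import mxnet as mx
from dgl.nn.mxnet import DenseGraphConv

adj = mx.nd.array([[1., 1., 0.],    # (N_dst, N_src) == (2, 3)
                   [0., 1., 1.]])
feat = mx.nd.random.randn(3, 5)     # (N_src, D_in)

conv = DenseGraphConv(5, 2, norm='right')  # average the received messages
conv.initialize()
out = conv(adj, feat)               # (N_dst, D_out) == (2, 2)
```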
...@@ -4,6 +4,7 @@ import math ...@@ -4,6 +4,7 @@ import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from ....utils import check_eq_shape
class DenseSAGEConv(nn.Block): class DenseSAGEConv(nn.Block):
...@@ -56,12 +57,18 @@ class DenseSAGEConv(nn.Block): ...@@ -56,12 +57,18 @@ class DenseSAGEConv(nn.Block):
Parameters Parameters
---------- ----------
adj : mxnet.NDArray adj : mxnet.NDArray
The adjacency matrix of the graph to apply Graph Convolution on, The adjacency matrix of the graph to apply SAGE Convolution on, when
should be of shape :math:`(N, N)`, where a row represents the destination applied to a unidirectional bipartite graph, ``adj`` should be of shape
and a column represents the source. should be of shape :math:`(N_{out}, N_{in})`; when applied to a homo
feat : mxnet.NDArray graph, ``adj`` should be of shape :math:`(N, N)`. In both cases,
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` a row represents a destination node while a column represents a source
is size of input feature, :math:`N` is the number of nodes. node.
feat : mxnet.NDArray or a pair of mxnet.NDArray
If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of
nodes.
If a pair of mxnet.NDArray is given, the pair must contain two tensors of
shape :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
Returns Returns
------- -------
...@@ -69,10 +76,15 @@ class DenseSAGEConv(nn.Block): ...@@ -69,10 +76,15 @@ class DenseSAGEConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
adj = adj.astype(feat.dtype).as_in_context(feat.context) check_eq_shape(feat)
feat = self.feat_drop(feat) if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
adj = adj.astype(feat_src.dtype).as_in_context(feat_src.context)
in_degrees = adj.sum(axis=1, keepdims=True) in_degrees = adj.sum(axis=1, keepdims=True)
h_neigh = (nd.dot(adj, feat) + feat) / (in_degrees + 1) h_neigh = (nd.dot(adj, feat_src) + feat_dst) / (in_degrees + 1)
rst = self.fc(h_neigh) rst = self.fc(h_neigh)
# activation # activation
if self.activation is not None: if self.activation is not None:
......
...@@ -4,6 +4,7 @@ import mxnet as mx ...@@ -4,6 +4,7 @@ import mxnet as mx
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class EdgeConv(nn.Block): class EdgeConv(nn.Block):
...@@ -60,17 +61,23 @@ class EdgeConv(nn.Block): ...@@ -60,17 +61,23 @@ class EdgeConv(nn.Block):
h : mxnet.NDArray h : mxnet.NDArray
:math:`(N, D)` where :math:`N` is the number of nodes and :math:`(N, D)` where :math:`N` is the number of nodes and
:math:`D` is the number of feature dimensions. :math:`D` is the number of feature dimensions.
If a pair of tensors is given, the graph must be a uni-bipartite graph
with only one edge type, and the two tensors must have the same
dimensionality on all except the first axis.
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
New node features. New node features.
""" """
with g.local_scope(): with g.local_scope():
g.ndata['x'] = h h_src, h_dst = expand_as_pair(h)
g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst
if not self.batch_norm: if not self.batch_norm:
g.update_all(self.message, fn.max('e', 'x')) g.update_all(self.message, fn.max('e', 'x'))
else: else:
g.apply_edges(self.message) g.apply_edges(self.message)
g.edata['e'] = self.bn(g.edata['e']) g.edata['e'] = self.bn(g.edata['e'])
g.update_all(fn.copy_e('e', 'm'), fn.max('m', 'x')) g.update_all(fn.copy_e('e', 'm'), fn.max('m', 'x'))
return g.ndata['x'] return g.dstdata['x']
...@@ -7,6 +7,7 @@ from mxnet.gluon.contrib.nn import Identity ...@@ -7,6 +7,7 @@ from mxnet.gluon.contrib.nn import Identity
from .... import function as fn from .... import function as fn
from ..softmax import edge_softmax from ..softmax import edge_softmax
from ....utils import expand_as_pair
#pylint: enable=W0235 #pylint: enable=W0235
class GATConv(nn.Block): class GATConv(nn.Block):
...@@ -26,8 +27,13 @@ class GATConv(nn.Block): ...@@ -26,8 +27,13 @@ class GATConv(nn.Block):
Parameters Parameters
---------- ----------
in_feats : int in_feats : int or pair of ints
Input feature size. Input feature size.
If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int out_feats : int
Output feature size. Output feature size.
num_heads : int num_heads : int
...@@ -55,12 +61,21 @@ class GATConv(nn.Block): ...@@ -55,12 +61,21 @@ class GATConv(nn.Block):
activation=None): activation=None):
super(GATConv, self).__init__() super(GATConv, self).__init__()
self._num_heads = num_heads self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
with self.name_scope(): with self.name_scope():
self.fc = nn.Dense(out_feats * num_heads, use_bias=False, if isinstance(in_feats, tuple):
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), self.fc_src = nn.Dense(out_feats * num_heads, use_bias=False,
in_units=in_feats) weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats)
self.fc_dst = nn.Dense(out_feats * num_heads, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats)
else:
self.fc = nn.Dense(out_feats * num_heads, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
self.attn_l = self.params.get('attn_l', self.attn_l = self.params.get('attn_l',
shape=(1, num_heads, out_feats), shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
...@@ -90,8 +105,11 @@ class GATConv(nn.Block): ...@@ -90,8 +105,11 @@ class GATConv(nn.Block):
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})`
is size of input feature, :math:`N` is the number of nodes. where :math:`D_{in}` is size of input feature, :math:`N` is the number of
nodes.
If a pair of mxnet.NDArray is given, the pair must contain two tensors of
shape :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns Returns
------- -------
...@@ -100,8 +118,17 @@ class GATConv(nn.Block): ...@@ -100,8 +118,17 @@ class GATConv(nn.Block):
is the number of heads, and :math:`D_{out}` is size of output feature. is the number of heads, and :math:`D_{out}` is size of output feature.
""" """
graph = graph.local_var() graph = graph.local_var()
h = self.feat_drop(feat) if isinstance(feat, tuple):
feat = self.fc(h).reshape(-1, self._num_heads, self._out_feats) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = self.fc_src(h_src).reshape(
-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).reshape(
-1, self._num_heads, self._out_feats)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).reshape(
-1, self._num_heads, self._out_feats)
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
...@@ -112,9 +139,10 @@ class GATConv(nn.Block): ...@@ -112,9 +139,10 @@ class GATConv(nn.Block):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = (feat * self.attn_l.data(feat.context)).sum(axis=-1).expand_dims(-1) el = (feat_src * self.attn_l.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
er = (feat * self.attn_r.data(feat.context)).sum(axis=-1).expand_dims(-1) er = (feat_dst * self.attn_r.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
graph.ndata.update({'ft': feat, 'el': el, 'er': er}) graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop('e'))
...@@ -122,10 +150,10 @@ class GATConv(nn.Block): ...@@ -122,10 +150,10 @@ class GATConv(nn.Block):
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft')) fn.sum('m', 'ft'))
rst = graph.ndata['ft'] rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = self.res_fc(h).reshape(h.shape[0], -1, self._out_feats) resval = self.res_fc(h_dst).reshape(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval rst = rst + resval
# activation # activation
if self.activation: if self.activation:
......
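A hedged sketch of the new bipartite path in the Gluon GATConv; `dgl.bipartite`, the relation names, and the feature sizes below are illustrative assumptions, not taken from this patch:

```python
# Hedged sketch: GATConv on a unidirectional bipartite graph with a
# (src, dst) feature pair of different sizes.
import mxnet as mx
import dgl
from dgl.nn.mxnet import GATConv

g = dgl.bipartite([(0, 0), (1, 1), (2, 2), (2, 3)], 'user', 'plays', 'game')
u_feat = mx.nd.random.randn(3, 5)    # (N_src, D_in_src)
v_feat = mx.nd.random.randn(4, 10)   # (N_dst, D_in_dst)

conv = GATConv(in_feats=(5, 10), out_feats=2, num_heads=4)
conv.initialize()
out = conv(g, (u_feat, v_feat))      # (N_dst, num_heads, D_out) == (4, 4, 2)
```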
...@@ -75,6 +75,8 @@ class GatedGraphConv(nn.Block): ...@@ -75,6 +75,8 @@ class GatedGraphConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
assert graph.is_homograph(), \
"not a homograph; convert it with to_homo and pass in the edge type as argument"
graph = graph.local_var() graph = graph.local_var()
zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]), ctx=feat.context) zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]), ctx=feat.context)
feat = nd.concat(feat, zero_pad, dim=-1) feat = nd.concat(feat, zero_pad, dim=-1)
......
...@@ -4,6 +4,7 @@ import mxnet as mx ...@@ -4,6 +4,7 @@ import mxnet as mx
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class GINConv(nn.Block): class GINConv(nn.Block):
...@@ -56,24 +57,28 @@ class GINConv(nn.Block): ...@@ -56,24 +57,28 @@ class GINConv(nn.Block):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : mxnet.NDArray or a pair of mxnet.NDArray
The input feature of shape :math:`(N, D)` where :math:`D` If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})`
could be any positive integer, :math:`N` is the number where :math:`D_{in}` is size of input feature, :math:`N` is the number of
of nodes. If ``apply_func`` is not None, :math:`D` should nodes.
If a pair of mxnet.NDArray is given, the pair must contain two tensors of
shape :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input dimensionality requirement of ``apply_func``. fit the input dimensionality requirement of ``apply_func``.
Returns Returns
------- -------
torch.Tensor mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where The output feature of shape :math:`(N, D_{out})` where
:math:`D_{out}` is the output dimensionality of ``apply_func``. :math:`D_{out}` is the output dimensionality of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality. as input dimensionality.
""" """
graph = graph.local_var() graph = graph.local_var()
graph.ndata['h'] = feat feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps.data(feat.context)) * feat + graph.ndata['neigh'] rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None: if self.apply_func is not None:
rst = self.apply_func(rst) rst = self.apply_func(rst)
return rst return rst
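A hedged sketch of GINConv consuming a feature pair; both tensors must share the input feature size so that the `(1 + eps) * feat_dst` term can be added to the aggregated neighborhood. Graph construction and sizes are assumptions:

```python
# Hedged sketch: GINConv with a (src, dst) feature pair of equal feature size.
import mxnet as mx
import dgl
from mxnet.gluon import nn
from dgl.nn.mxnet import GINConv

g = dgl.bipartite([(0, 0), (0, 1), (1, 2)], 'A', 'to', 'B')
feat_a = mx.nd.random.randn(2, 6)    # (N_src, D_in)
feat_b = mx.nd.random.randn(3, 6)    # (N_dst, D_in), same D_in

conv = GINConv(apply_func=nn.Dense(4, in_units=6), aggregator_type='sum')
conv.initialize()
out = conv(g, (feat_a, feat_b))      # (N_dst, 4)
```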
...@@ -7,6 +7,7 @@ from mxnet.gluon import nn ...@@ -7,6 +7,7 @@ from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity from mxnet.gluon.contrib.nn import Identity
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class GMMConv(nn.Block): class GMMConv(nn.Block):
...@@ -22,8 +23,13 @@ class GMMConv(nn.Block): ...@@ -22,8 +23,13 @@ class GMMConv(nn.Block):
Parameters Parameters
---------- ----------
in_feats : int in_feats : int, or pair of ints
Number of input features. Number of input features.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int out_feats : int
Number of output features. Number of output features.
dim : int dim : int
...@@ -46,7 +52,8 @@ class GMMConv(nn.Block): ...@@ -46,7 +52,8 @@ class GMMConv(nn.Block):
residual=False, residual=False,
bias=True): bias=True):
super(GMMConv, self).__init__() super(GMMConv, self).__init__()
self._in_feats = in_feats
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self._dim = dim self._dim = dim
self._n_kernels = n_kernels self._n_kernels = n_kernels
...@@ -67,12 +74,12 @@ class GMMConv(nn.Block): ...@@ -67,12 +74,12 @@ class GMMConv(nn.Block):
shape=(n_kernels, dim), shape=(n_kernels, dim),
init=mx.init.Constant(1)) init=mx.init.Constant(1))
self.fc = nn.Dense(n_kernels * out_feats, self.fc = nn.Dense(n_kernels * out_feats,
in_units=in_feats, in_units=self._in_src_feats,
use_bias=False, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)))
if residual: if residual:
if in_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Dense(out_feats, in_units=in_feats, use_bias=False) self.res_fc = nn.Dense(out_feats, in_units=self._in_dst_feats, use_bias=False)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -93,9 +100,10 @@ class GMMConv(nn.Block): ...@@ -93,9 +100,10 @@ class GMMConv(nn.Block):
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`N` If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where
is the number of nodes of the graph and :math:`D_{in}` is the :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
input feature size. If a pair of tensors is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
pseudo : mxnet.NDArray pseudo : mxnet.NDArray
The pseudo coordinate tensor of shape :math:`(E, D_{u})` where The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
:math:`E` is the number of edges of the graph and :math:`D_{u}` :math:`E` is the number of edges of the graph and :math:`D_{u}`
...@@ -107,22 +115,26 @@ class GMMConv(nn.Block): ...@@ -107,22 +115,26 @@ class GMMConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
graph = graph.local_var() feat_src, feat_dst = expand_as_pair(feat)
graph.ndata['h'] = self.fc(feat).reshape(-1, self._n_kernels, self._out_feats) with graph.local_scope():
E = graph.number_of_edges() graph.srcdata['h'] = self.fc(feat_src).reshape(
# compute gaussian weight -1, self._n_kernels, self._out_feats)
gaussian = -0.5 * ((pseudo.reshape(E, 1, self._dim) - E = graph.number_of_edges()
self.mu.data(feat.context).reshape(1, self._n_kernels, self._dim)) ** 2) # compute gaussian weight
gaussian = gaussian *\ gaussian = -0.5 * ((pseudo.reshape(E, 1, self._dim) -
(self.inv_sigma.data(feat.context).reshape(1, self._n_kernels, self._dim) ** 2) self.mu.data(feat_src.context)
gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1) .reshape(1, self._n_kernels, self._dim)) ** 2)
graph.edata['w'] = gaussian gaussian = gaussian *\
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h')) (self.inv_sigma.data(feat_src.context)
rst = graph.ndata['h'].sum(1) .reshape(1, self._n_kernels, self._dim) ** 2)
# residual connection gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1)
if self.res_fc is not None: graph.edata['w'] = gaussian
rst = rst + self.res_fc(feat) graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
# bias rst = graph.dstdata['h'].sum(1)
if self.bias is not None: # residual connection
rst = rst + self.bias.data(feat.context) if self.res_fc is not None:
return rst rst = rst + self.res_fc(feat_dst)
# bias
if self.bias is not None:
rst = rst + self.bias.data(feat_dst.context)
return rst
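A hedged GMMConv sketch under the same bipartite assumptions, with per-edge pseudo-coordinates; the graph construction, layer sizes, and keyword usage are illustrative:

```python
# Hedged sketch: GMMConv with a (src, dst) feature pair and per-edge
# pseudo-coordinates of shape (E, dim).
import mxnet as mx
import dgl
from dgl.nn.mxnet import GMMConv

g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'A', 'to', 'B')
feat_a = mx.nd.random.randn(2, 5)    # (N_src, D_in_src)
feat_b = mx.nd.random.randn(3, 8)    # (N_dst, D_in_dst)
pseudo = mx.nd.random.randn(3, 2)    # (E, dim)

conv = GMMConv(in_feats=(5, 8), out_feats=4, dim=2, n_kernels=3, residual=True)
conv.initialize()
out = conv(g, (feat_a, feat_b), pseudo)   # (N_dst, 4)
```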
...@@ -110,7 +110,7 @@ class GraphConv(gluon.Block): ...@@ -110,7 +110,7 @@ class GraphConv(gluon.Block):
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature The input feature.
weight : torch.Tensor, optional weight : torch.Tensor, optional
Optional external weight tensor. Optional external weight tensor.
......
...@@ -5,6 +5,7 @@ from mxnet.gluon import nn ...@@ -5,6 +5,7 @@ from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity from mxnet.gluon.contrib.nn import Identity
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class NNConv(nn.Block): class NNConv(nn.Block):
...@@ -17,8 +18,13 @@ class NNConv(nn.Block): ...@@ -17,8 +18,13 @@ class NNConv(nn.Block):
Parameters Parameters
---------- ----------
in_feats : int in_feats : int or pair of ints
Input feature size. Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int out_feats : int
Output feature size. Output feature size.
edge_func : callable activation function/layer edge_func : callable activation function/layer
...@@ -41,7 +47,7 @@ class NNConv(nn.Block): ...@@ -41,7 +47,7 @@ class NNConv(nn.Block):
residual=False, residual=False,
bias=True): bias=True):
super(NNConv, self).__init__() super(NNConv, self).__init__()
self._in_feats = in_feats self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
if aggregator_type == 'sum': if aggregator_type == 'sum':
self.reducer = fn.sum self.reducer = fn.sum
...@@ -56,9 +62,10 @@ class NNConv(nn.Block): ...@@ -56,9 +62,10 @@ class NNConv(nn.Block):
with self.name_scope(): with self.name_scope():
self.edge_nn = edge_func self.edge_nn = edge_func
if residual: if residual:
if in_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Dense(out_feats, in_units=in_feats, use_bias=False, self.res_fc = nn.Dense(
weight_initializer=mx.init.Xavier()) out_feats, in_units=self._in_dst_feats,
use_bias=False, weight_initializer=mx.init.Xavier())
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -78,7 +85,7 @@ class NNConv(nn.Block): ...@@ -78,7 +85,7 @@ class NNConv(nn.Block):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray or pair of mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`N` The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the is the number of nodes of the graph and :math:`D_{in}` is the
input feature size. input feature size.
...@@ -92,18 +99,20 @@ class NNConv(nn.Block): ...@@ -92,18 +99,20 @@ class NNConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
graph = graph.local_var() with graph.local_scope():
# (n, d_in, 1) feat_src, feat_dst = expand_as_pair(feat)
graph.ndata['h'] = feat.expand_dims(-1)
# (n, d_in, d_out) # (n, d_in, 1)
graph.edata['w'] = self.edge_nn(efeat).reshape(-1, self._in_feats, self._out_feats) graph.srcdata['h'] = feat_src.expand_dims(-1)
# (n, d_in, d_out) # (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh')) graph.edata['w'] = self.edge_nn(efeat).reshape(-1, self._in_src_feats, self._out_feats)
rst = graph.ndata.pop('neigh').sum(axis=1) # (n, d_out) # (n, d_in, d_out)
# residual connection graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh'))
if self.res_fc is not None: rst = graph.dstdata.pop('neigh').sum(axis=1) # (n, d_out)
rst = rst + self.res_fc(feat) # residual connection
# bias if self.res_fc is not None:
if self.bias is not None: rst = rst + self.res_fc(feat_dst)
rst = rst + self.bias.data(feat.context) # bias
return rst if self.bias is not None:
rst = rst + self.bias.data(feat_dst.context)
return rst
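A hedged NNConv sketch; `edge_func` must map each edge feature to `D_in_src * D_out` values, matching the reshape in the forward pass above. Graph construction and sizes are assumptions:

```python
# Hedged sketch: NNConv with a feature pair; edge_func emits
# D_in_src * D_out weights per edge.
import mxnet as mx
import dgl
from mxnet.gluon import nn
from dgl.nn.mxnet import NNConv

g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'A', 'to', 'B')
feat_a = mx.nd.random.randn(2, 5)    # (N_src, D_in_src)
feat_b = mx.nd.random.randn(3, 6)    # (N_dst, D_in_dst)
efeat = mx.nd.random.randn(3, 4)     # (E, D_edge)

edge_func = nn.Dense(5 * 2, in_units=4)   # (E, D_edge) -> (E, D_in_src * D_out)
conv = NNConv(in_feats=(5, 6), out_feats=2, edge_func=edge_func,
              aggregator_type='sum')
conv.initialize()
out = conv(g, (feat_a, feat_b), efeat)    # (N_dst, 2)
```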
...@@ -175,25 +175,27 @@ class RelGraphConv(gluon.Block): ...@@ -175,25 +175,27 @@ class RelGraphConv(gluon.Block):
mx.ndarray.NDArray mx.ndarray.NDArray
New node features. New node features.
""" """
g = g.local_var() assert g.is_homograph(), \
g.ndata['h'] = x "not a homograph; convert it with to_homo and pass in the edge type as argument"
g.edata['type'] = etypes with g.local_scope():
if norm is not None: g.ndata['h'] = x
g.edata['norm'] = norm g.edata['type'] = etypes
if self.self_loop: if norm is not None:
loop_message = utils.matmul_maybe_select(x, self.loop_weight.data(x.context)) g.edata['norm'] = norm
if self.self_loop:
# message passing loop_message = utils.matmul_maybe_select(x, self.loop_weight.data(x.context))
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
# message passing
# apply bias and activation g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
node_repr = g.ndata['h']
if self.bias: # apply bias and activation
node_repr = node_repr + self.h_bias.data(x.context) node_repr = g.ndata['h']
if self.self_loop: if self.bias:
node_repr = node_repr + loop_message node_repr = node_repr + self.h_bias.data(x.context)
if self.activation: if self.self_loop:
node_repr = self.activation(node_repr) node_repr = node_repr + loop_message
node_repr = self.dropout(node_repr) if self.activation:
node_repr = self.activation(node_repr)
return node_repr node_repr = self.dropout(node_repr)
return node_repr
"""MXNet Module for GraphSAGE layer""" """MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
from numbers import Integral
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Block): class SAGEConv(nn.Block):
r"""GraphSAGE layer from paper `Inductive Representation Learning on r"""GraphSAGE layer from paper `Inductive Representation Learning on
...@@ -57,14 +57,7 @@ class SAGEConv(nn.Block): ...@@ -57,14 +57,7 @@ class SAGEConv(nn.Block):
activation=None): activation=None):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
if isinstance(in_feats, tuple): self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')
self._out_feats = out_feats self._out_feats = out_feats
self._aggre_type = aggregator_type self._aggre_type = aggregator_type
with self.name_scope(): with self.name_scope():
...@@ -92,9 +85,11 @@ class SAGEConv(nn.Block): ...@@ -92,9 +85,11 @@ class SAGEConv(nn.Block):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray or pair of mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where
is size of input feature, :math:`N` is the number of nodes. :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns Returns
------- -------
...@@ -117,6 +112,7 @@ class SAGEConv(nn.Block): ...@@ -117,6 +112,7 @@ class SAGEConv(nn.Block):
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn': elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
......
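A hedged SAGEConv sketch for the bipartite case; `dgl.bipartite`, the relation names, and the sizes are illustrative assumptions:

```python
# Hedged sketch: SAGEConv with distinct source/destination feature sizes.
import mxnet as mx
import dgl
from dgl.nn.mxnet import SAGEConv

g = dgl.bipartite([(0, 0), (0, 1), (1, 2)], 'user', 'rates', 'item')
u_feat = mx.nd.random.randn(2, 5)    # (N_src, D_in_src)
v_feat = mx.nd.random.randn(3, 10)   # (N_dst, D_in_dst)

conv = SAGEConv(in_feats=(5, 10), out_feats=2, aggregator_type='mean')
conv.initialize()
out = conv(g, (u_feat, v_feat))      # (N_dst, D_out) == (3, 2)
```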
...@@ -76,6 +76,7 @@ class TAGConv(gluon.Block): ...@@ -76,6 +76,7 @@ class TAGConv(gluon.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
assert graph.is_homograph(), 'Graph is not homogeneous'
graph = graph.local_var() graph = graph.local_var()
degs = graph.in_degrees().astype('float32') degs = graph.in_degrees().astype('float32')
......
...@@ -6,6 +6,7 @@ from torch.nn import functional as F ...@@ -6,6 +6,7 @@ from torch.nn import functional as F
from .... import function as fn from .... import function as fn
from ..softmax import edge_softmax from ..softmax import edge_softmax
from ....utils import expand_as_pair
class AGNNConv(nn.Module): class AGNNConv(nn.Module):
...@@ -47,6 +48,9 @@ class AGNNConv(nn.Module): ...@@ -47,6 +48,9 @@ class AGNNConv(nn.Module):
feat : torch.Tensor feat : torch.Tensor
The input feature of shape :math:`(N, *)` :math:`N` is the The input feature of shape :math:`(N, *)` :math:`N` is the
number of nodes, and :math:`*` could be of any shape. number of nodes, and :math:`*` could be of any shape.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, *)` and :math:`(N_{out}, *)`; the :math:`*` in the latter
tensor must equal that in the former.
Returns Returns
------- -------
...@@ -55,12 +59,16 @@ class AGNNConv(nn.Module): ...@@ -55,12 +59,16 @@ class AGNNConv(nn.Module):
should be the same as input shape. should be the same as input shape.
""" """
graph = graph.local_var() graph = graph.local_var()
graph.ndata['h'] = feat
graph.ndata['norm_h'] = F.normalize(feat, p=2, dim=-1) feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = feat_src
graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1)
if isinstance(feat, tuple):
graph.dstdata['norm_h'] = F.normalize(feat_dst, p=2, dim=-1)
# compute cosine distance # compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos')) graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
cos = graph.edata.pop('cos') cos = graph.edata.pop('cos')
e = self.beta * cos e = self.beta * cos
graph.edata['p'] = edge_softmax(graph, e) graph.edata['p'] = edge_softmax(graph, e)
graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h')) graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
return graph.ndata.pop('h') return graph.dstdata.pop('h')
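A hedged sketch of AGNNConv with a feature pair; the trailing dimensions of the two tensors must match, as the docstring above requires. The graph construction is an assumption:

```python
# Hedged sketch: AGNNConv on a bipartite graph; src/dst features share
# their trailing shape.
import torch as th
import dgl
from dgl.nn.pytorch import AGNNConv

g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'A', 'to', 'B')
feat_a = th.randn(2, 8)              # (N_src, *)
feat_b = th.randn(3, 8)              # (N_dst, *), same trailing shape

conv = AGNNConv()
out = conv(g, (feat_a, feat_b))      # (N_dst, *) == (3, 8)
```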
...@@ -17,8 +17,11 @@ class DenseGraphConv(nn.Module): ...@@ -17,8 +17,11 @@ class DenseGraphConv(nn.Module):
Input feature size. Input feature size.
out_feats : int out_feats : int
Output feature size. Output feature size.
norm : bool norm : str, optional
If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``. How to apply the normalizer. If it is `'right'`, divide the aggregated messages
by each node's in-degrees, which is equivalent to averaging the received messages.
If it is `'none'`, no normalization is applied. Default is `'both'`,
where the :math:`c_{ij}` in the paper is applied.
bias : bool bias : bool
If True, adds a learnable bias to the output. Default: ``True``. If True, adds a learnable bias to the output. Default: ``True``.
activation : callable activation function/layer or None, optional activation : callable activation function/layer or None, optional
...@@ -32,7 +35,7 @@ class DenseGraphConv(nn.Module): ...@@ -32,7 +35,7 @@ class DenseGraphConv(nn.Module):
def __init__(self, def __init__(self,
in_feats, in_feats,
out_feats, out_feats,
norm=True, norm='both',
bias=True, bias=True,
activation=None): activation=None):
super(DenseGraphConv, self).__init__() super(DenseGraphConv, self).__init__()
...@@ -60,12 +63,14 @@ class DenseGraphConv(nn.Module): ...@@ -60,12 +63,14 @@ class DenseGraphConv(nn.Module):
Parameters Parameters
---------- ----------
adj : torch.Tensor adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on, The adjacency matrix of the graph to apply Graph Convolution on, when
should be of shape :math:`(N, N)`, where a row represents the destination applied to a unidirectional bipartite graph, ``adj`` should be of shape
and a column represents the source. should be of shape :math:`(N_{out}, N_{in})`; when applied to a homo
graph, ``adj`` should be of shape :math:`(N, N)`. In both cases,
a row represents a destination node while a column represents a source
node.
feat : torch.Tensor feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` The input feature.
is size of input feature, :math:`N` is the number of nodes.
Returns Returns
------- -------
...@@ -74,24 +79,33 @@ class DenseGraphConv(nn.Module): ...@@ -74,24 +79,33 @@ class DenseGraphConv(nn.Module):
is size of output feature. is size of output feature.
""" """
adj = adj.float().to(feat.device) adj = adj.float().to(feat.device)
if self._norm: src_degrees = adj.sum(dim=0).clamp(min=1)
in_degrees = adj.sum(dim=1) dst_degrees = adj.sum(dim=1).clamp(min=1)
norm = th.pow(in_degrees, -0.5) feat_src = feat
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device) if self._norm == 'both':
feat = feat * norm norm_src = th.pow(src_degrees, -0.5)
shp = norm_src.shape + (1,) * (feat.dim() - 1)
norm_src = th.reshape(norm_src, shp).to(feat.device)
feat_src = feat_src * norm_src
if self._in_feats > self._out_feats: if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation. # mult W first to reduce the feature size for aggregation.
feat = th.matmul(feat, self.weight) feat_src = th.matmul(feat_src, self.weight)
rst = adj @ feat rst = adj @ feat_src
else: else:
# aggregate first then mult W # aggregate first then mult W
rst = adj @ feat rst = adj @ feat_src
rst = th.matmul(rst, self.weight) rst = th.matmul(rst, self.weight)
if self._norm: if self._norm != 'none':
rst = rst * norm if self._norm == 'both':
norm_dst = th.pow(dst_degrees, -0.5)
else: # right
norm_dst = 1.0 / dst_degrees
shp = norm_dst.shape + (1,) * (feat.dim() - 1)
norm_dst = th.reshape(norm_dst, shp).to(feat.device)
rst = rst * norm_dst
if self.bias is not None: if self.bias is not None:
rst = rst + self.bias rst = rst + self.bias
......
"""Torch Module for DenseSAGEConv""" """Torch Module for DenseSAGEConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn from torch import nn
from ....utils import check_eq_shape
class DenseSAGEConv(nn.Module): class DenseSAGEConv(nn.Module):
...@@ -57,12 +58,17 @@ class DenseSAGEConv(nn.Module): ...@@ -57,12 +58,17 @@ class DenseSAGEConv(nn.Module):
Parameters Parameters
---------- ----------
adj : torch.Tensor adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on, The adjacency matrix of the graph to apply SAGE Convolution on, when
should be of shape :math:`(N, N)`, where a row represents the destination applied to a unidirectional bipartite graph, ``adj`` should be of shape
and a column represents the source. should be of shape :math:`(N_{out}, N_{in})`; when applied to a homo
feat : torch.Tensor graph, ``adj`` should be of shape :math:`(N, N)`. In both cases,
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` a row represents a destination node while a column represents a source
is size of input feature, :math:`N` is the number of nodes. node.
feat : torch.Tensor or a pair of torch.Tensor
If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
Returns Returns
------- -------
...@@ -70,10 +76,15 @@ class DenseSAGEConv(nn.Module): ...@@ -70,10 +76,15 @@ class DenseSAGEConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
adj = adj.float().to(feat.device) check_eq_shape(feat)
feat = self.feat_drop(feat) if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
adj = adj.float().to(feat_src.device)
in_degrees = adj.sum(dim=1, keepdim=True) in_degrees = adj.sum(dim=1, keepdim=True)
h_neigh = (adj @ feat + feat) / (in_degrees + 1) h_neigh = (adj @ feat_src + feat_dst) / (in_degrees + 1)
rst = self.fc(h_neigh) rst = self.fc(h_neigh)
# activation # activation
if self.activation is not None: if self.activation is not None:
......
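A hedged DenseSAGEConv sketch with a bipartite-shaped adjacency (rows are destinations, columns are sources); the sizes are assumptions:

```python
# Hedged sketch: DenseSAGEConv with a (src, dst) feature pair; both tensors
# must share the feature dimension, as check_eq_shape enforces.
import torch as th
from dgl.nn.pytorch import DenseSAGEConv

adj = th.tensor([[1., 0., 1.],       # (N_dst, N_src) == (2, 3)
                 [0., 1., 1.]])
feat_src = th.randn(3, 5)            # (N_src, D_in)
feat_dst = th.randn(2, 5)            # (N_dst, D_in), same D_in

conv = DenseSAGEConv(5, 4)
out = conv(adj, (feat_src, feat_dst))   # (N_dst, D_out) == (2, 4)
```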
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from torch import nn from torch import nn
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class EdgeConv(nn.Module): class EdgeConv(nn.Module):
...@@ -53,16 +54,22 @@ class EdgeConv(nn.Module): ...@@ -53,16 +54,22 @@ class EdgeConv(nn.Module):
---------- ----------
g : DGLGraph g : DGLGraph
The graph. The graph.
h : Tensor h : Tensor or pair of tensors
:math:`(N, D)` where :math:`N` is the number of nodes and :math:`(N, D)` where :math:`N` is the number of nodes and
:math:`D` is the number of feature dimensions. :math:`D` is the number of feature dimensions.
If a pair of tensors is given, the graph must be a uni-bipartite graph
with only one edge type, and the two tensors must have the same
dimensionality on all except the first axis.
Returns Returns
------- -------
torch.Tensor torch.Tensor
New node features. New node features.
""" """
with g.local_scope(): with g.local_scope():
g.ndata['x'] = h h_src, h_dst = expand_as_pair(h)
g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst
if not self.batch_norm: if not self.batch_norm:
g.update_all(self.message, fn.max('e', 'x')) g.update_all(self.message, fn.max('e', 'x'))
else: else:
...@@ -88,4 +95,4 @@ class EdgeConv(nn.Module): ...@@ -88,4 +95,4 @@ class EdgeConv(nn.Module):
# images. # images.
g.edata['e'] = self.bn(g.edata['e']) g.edata['e'] = self.bn(g.edata['e'])
g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x')) g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
return g.ndata['x'] return g.dstdata['x']
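A hedged EdgeConv sketch for the pair-input path; as the docstring notes, both tensors must have the same dimensionality on all but the first axis. Graph construction and sizes are assumptions:

```python
# Hedged sketch: EdgeConv with a (src, dst) feature pair.
import torch as th
import dgl
from dgl.nn.pytorch import EdgeConv

g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'A', 'to', 'B')
h_a = th.randn(2, 6)                 # (N_src, D)
h_b = th.randn(3, 6)                 # (N_dst, D), same D

conv = EdgeConv(6, 10)
out = conv(g, (h_a, h_b))            # (N_dst, 10)
```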
...@@ -6,6 +6,7 @@ from torch import nn ...@@ -6,6 +6,7 @@ from torch import nn
from .... import function as fn from .... import function as fn
from ..softmax import edge_softmax from ..softmax import edge_softmax
from ..utils import Identity from ..utils import Identity
from ....utils import expand_as_pair
# pylint: enable=W0235 # pylint: enable=W0235
class GATConv(nn.Module): class GATConv(nn.Module):
...@@ -25,8 +26,13 @@ class GATConv(nn.Module): ...@@ -25,8 +26,13 @@ class GATConv(nn.Module):
Parameters Parameters
---------- ----------
in_feats : int in_feats : int, or pair of ints
Input feature size. Input feature size.
If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int out_feats : int
Output feature size. Output feature size.
num_heads : int num_heads : int
...@@ -54,17 +60,25 @@ class GATConv(nn.Module): ...@@ -54,17 +60,25 @@ class GATConv(nn.Module):
activation=None): activation=None):
super(GATConv, self).__init__() super(GATConv, self).__init__()
self._num_heads = num_heads self._num_heads = num_heads
self._in_feats = in_feats self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False) if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=False)
else:
self.fc = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual: if residual:
if in_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(in_feats, num_heads * out_feats, bias=False) self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=False)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -75,7 +89,11 @@ class GATConv(nn.Module): ...@@ -75,7 +89,11 @@ class GATConv(nn.Module):
def reset_parameters(self): def reset_parameters(self):
"""Reinitialize learnable parameters.""" """Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.fc.weight, gain=gain) if hasattr(self, 'fc'):
nn.init.xavier_normal_(self.fc.weight, gain=gain)
else: # bipartite graph neural networks
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain) nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain) nn.init.xavier_normal_(self.attn_r, gain=gain)
if isinstance(self.res_fc, nn.Linear): if isinstance(self.res_fc, nn.Linear):
...@@ -88,9 +106,11 @@ class GATConv(nn.Module): ...@@ -88,9 +106,11 @@ class GATConv(nn.Module):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor or pair of torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
is size of input feature, :math:`N` is the number of nodes. :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns Returns
------- -------
...@@ -99,8 +119,15 @@ class GATConv(nn.Module): ...@@ -99,8 +119,15 @@ class GATConv(nn.Module):
is the number of heads, and :math:`D_{out}` is size of output feature. is the number of heads, and :math:`D_{out}` is size of output feature.
""" """
graph = graph.local_var() graph = graph.local_var()
h = self.feat_drop(feat) if isinstance(feat, tuple):
feat = self.fc(h).view(-1, self._num_heads, self._out_feats) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(
-1, self._num_heads, self._out_feats)
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
...@@ -111,9 +138,10 @@ class GATConv(nn.Module): ...@@ -111,9 +138,10 @@ class GATConv(nn.Module):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1) el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1) er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.ndata.update({'ft': feat, 'el': el, 'er': er}) graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop('e'))
...@@ -122,10 +150,10 @@ class GATConv(nn.Module): ...@@ -122,10 +150,10 @@ class GATConv(nn.Module):
# message passing # message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft')) fn.sum('m', 'ft'))
rst = graph.ndata['ft'] rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = self.res_fc(h).view(h.shape[0], -1, self._out_feats) resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval rst = rst + resval
# activation # activation
if self.activation: if self.activation:
......
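A hedged PyTorch counterpart of the GATConv sketch above; `dgl.bipartite`, the relation names, and the sizes are illustrative assumptions, not part of this patch:

```python
# Hedged sketch: GATConv on a unidirectional bipartite graph with separate
# source/destination feature sizes.
import torch as th
import dgl
from dgl.nn.pytorch import GATConv

g = dgl.bipartite([(0, 0), (1, 1), (2, 2), (2, 3)], 'user', 'plays', 'game')
u_feat = th.randn(3, 5)              # (N_src, D_in_src)
v_feat = th.randn(4, 10)             # (N_dst, D_in_dst)

conv = GATConv(in_feats=(5, 10), out_feats=2, num_heads=4)
out = conv(g, (u_feat, v_feat))      # (N_dst, num_heads, D_out) == (4, 4, 2)
```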