Unverified commit 63ac788f, authored by Hongzhi (Steve), Chen and committed by GitHub

auto-reformat-nn (#5319)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 0b3a447b
@@ -6,8 +6,7 @@ import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

from .... import broadcast_nodes, function as fn
from ....base import dgl_warning
...
@@ -98,20 +98,21 @@ class EdgeConv(nn.Block):
     [-1.015364   0.78919804]]
    <NDArray 4x2 @cpu(0)>
    """

    def __init__(
        self, in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False
    ):
        super(EdgeConv, self).__init__()
        self.batch_norm = batch_norm
        self._allow_zero_in_degree = allow_zero_in_degree
        with self.name_scope():
            self.theta = nn.Dense(
                out_feat, in_units=in_feat, weight_initializer=mx.init.Xavier()
            )
            self.phi = nn.Dense(
                out_feat, in_units=in_feat, weight_initializer=mx.init.Xavier()
            )

            if batch_norm:
                self.bn = nn.BatchNorm(in_channels=out_feat)
@@ -164,26 +165,28 @@ class EdgeConv(nn.Block):
        with g.local_scope():
            if not self._allow_zero_in_degree:
                if g.in_degrees().min() == 0:
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            h_src, h_dst = expand_as_pair(h, g)
            g.srcdata["x"] = h_src
            g.dstdata["x"] = h_dst
            g.apply_edges(fn.v_sub_u("x", "x", "theta"))
            g.edata["theta"] = self.theta(g.edata["theta"])
            g.dstdata["phi"] = self.phi(g.dstdata["x"])
            if not self.batch_norm:
                g.update_all(fn.e_add_v("theta", "phi", "e"), fn.max("e", "x"))
            else:
                g.apply_edges(fn.e_add_v("theta", "phi", "e"))
                g.edata["e"] = self.bn(g.edata["e"])
                g.update_all(fn.copy_e("e", "m"), fn.max("m", "x"))
            return g.dstdata["x"]
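The error message above points users to `dgl.add_self_loop` before running the layer. A minimal usage sketch of the MXNet EdgeConv along those lines, illustrative only and assuming the DGL backend is set to MXNet (it is not part of this commit):

import dgl
import mxnet as mx
from dgl.nn.mxnet import EdgeConv

# Build a small 4-node cycle and add self-loops so no node has 0 in-degree.
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
g = dgl.add_self_loop(g)
feat = mx.nd.random.uniform(shape=(4, 10))
conv = EdgeConv(10, 2)
conv.initialize(ctx=mx.cpu(0))
out = conv(g, feat)  # NDArray of shape (4, 2)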
"""MXNet modules for graph attention networks(GAT).""" """MXNet modules for graph attention networks(GAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet.gluon import nn from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity from mxnet.gluon.contrib.nn import Identity
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
from ...functional import edge_softmax
from ....utils import expand_as_pair from ....utils import expand_as_pair
from ...functional import edge_softmax
#pylint: enable=W0235
# pylint: enable=W0235
class GATConv(nn.Block): class GATConv(nn.Block):
r"""Graph attention layer from `Graph Attention Network r"""Graph attention layer from `Graph Attention Network
<https://arxiv.org/pdf/1710.10903.pdf>`__ <https://arxiv.org/pdf/1710.10903.pdf>`__
...@@ -134,16 +136,19 @@ class GATConv(nn.Block): ...@@ -134,16 +136,19 @@ class GATConv(nn.Block):
[-1.9325689 1.3824553 ]]] [-1.9325689 1.3824553 ]]]
<NDArray 4x3x2 @cpu(0)> <NDArray 4x3x2 @cpu(0)>
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
num_heads, num_heads,
feat_drop=0., feat_drop=0.0,
attn_drop=0., attn_drop=0.0,
negative_slope=0.2, negative_slope=0.2,
residual=False, residual=False,
activation=None, activation=None,
allow_zero_in_degree=False): allow_zero_in_degree=False,
):
super(GATConv, self).__init__() super(GATConv, self).__init__()
self._num_heads = num_heads self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
...@@ -152,31 +157,48 @@ class GATConv(nn.Block): ...@@ -152,31 +157,48 @@ class GATConv(nn.Block):
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
with self.name_scope(): with self.name_scope():
if isinstance(in_feats, tuple): if isinstance(in_feats, tuple):
self.fc_src = nn.Dense(out_feats * num_heads, use_bias=False, self.fc_src = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats) in_units=self._in_src_feats,
self.fc_dst = nn.Dense(out_feats * num_heads, use_bias=False, )
self.fc_dst = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats) in_units=self._in_dst_feats,
)
else: else:
self.fc = nn.Dense(out_feats * num_heads, use_bias=False, self.fc = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats) in_units=in_feats,
self.attn_l = self.params.get('attn_l', )
self.attn_l = self.params.get(
"attn_l",
shape=(1, num_heads, out_feats), shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
self.attn_r = self.params.get('attn_r', )
self.attn_r = self.params.get(
"attn_r",
shape=(1, num_heads, out_feats), shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual: if residual:
if in_feats != out_feats: if in_feats != out_feats:
self.res_fc = nn.Dense(out_feats * num_heads, use_bias=False, self.res_fc = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier( weight_initializer=mx.init.Xavier(
magnitude=math.sqrt(2.0)), magnitude=math.sqrt(2.0)
in_units=in_feats) ),
in_units=in_feats,
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -235,15 +257,17 @@ class GATConv(nn.Block): ...@@ -235,15 +257,17 @@ class GATConv(nn.Block):
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if graph.in_degrees().min() == 0: if graph.in_degrees().min() == 0:
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple): if isinstance(feat, tuple):
src_prefix_shape = feat[0].shape[:-1] src_prefix_shape = feat[0].shape[:-1]
...@@ -251,22 +275,27 @@ class GATConv(nn.Block): ...@@ -251,22 +275,27 @@ class GATConv(nn.Block):
feat_dim = feat[0].shape[-1] feat_dim = feat[0].shape[-1]
h_src = self.feat_drop(feat[0]) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1]) h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'): if not hasattr(self, "fc_src"):
self.fc_src, self.fc_dst = self.fc, self.fc self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape( feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape(
*src_prefix_shape, self._num_heads, self._out_feats) *src_prefix_shape, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape( feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape(
*dst_prefix_shape, self._num_heads, self._out_feats) *dst_prefix_shape, self._num_heads, self._out_feats
)
else: else:
src_prefix_shape = dst_prefix_shape = feat.shape[:-1] src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
feat_dim = feat[0].shape[-1] feat_dim = feat[0].shape[-1]
h_src = h_dst = self.feat_drop(feat) h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src.reshape(-1, feat_dim)).reshape( feat_src = feat_dst = self.fc(
*src_prefix_shape, self._num_heads, self._out_feats) h_src.reshape(-1, feat_dim)
).reshape(*src_prefix_shape, self._num_heads, self._out_feats)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()] h_dst = h_dst[: graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] dst_prefix_shape = (
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
...@@ -277,28 +306,36 @@ class GATConv(nn.Block): ...@@ -277,28 +306,36 @@ class GATConv(nn.Block):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l.data(feat_src.context)).sum(axis=-1).expand_dims(-1) el = (
er = (feat_dst * self.attn_r.data(feat_src.context)).sum(axis=-1).expand_dims(-1) (feat_src * self.attn_l.data(feat_src.context))
graph.srcdata.update({'ft': feat_src, 'el': el}) .sum(axis=-1)
graph.dstdata.update({'er': er}) .expand_dims(-1)
)
er = (
(feat_dst * self.attn_r.data(feat_src.context))
.sum(axis=-1)
.expand_dims(-1)
)
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax # compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
fn.sum('m', 'ft')) rst = graph.dstdata["ft"]
rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = self.res_fc(h_dst.reshape(-1, feat_dim)).reshape( resval = self.res_fc(h_dst.reshape(-1, feat_dim)).reshape(
*dst_prefix_shape, -1, self._out_feats) *dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval rst = rst + resval
# activation # activation
if self.activation: if self.activation:
rst = self.activation(rst) rst = self.activation(rst)
if get_attention: if get_attention:
return rst, graph.edata['a'] return rst, graph.edata["a"]
else: else:
return rst return rst
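For reference, the equivalence the NOTE in the hunk above appeals to is just the split of the attention vector a into [a_l || a_r]:

\[
a^{\top}\bigl[W h_i \,\Vert\, W h_j\bigr] \;=\; a_l^{\top} W h_i + a_r^{\top} W h_j,
\qquad a = [a_l \,\Vert\, a_r],
\]

so the per-node terms el and er can be combined per edge with fn.u_add_v instead of materializing the concatenation [Wh_i || Wh_j] on every edge.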
"""Torch Module for GMM Conv""" """Torch Module for GMM Conv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
...@@ -107,15 +108,18 @@ class GMMConv(nn.Block): ...@@ -107,15 +108,18 @@ class GMMConv(nn.Block):
[-0.1005067 -0.09494358]] [-0.1005067 -0.09494358]]
<NDArray 4x2 @cpu(0)> <NDArray 4x2 @cpu(0)>
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
dim, dim,
n_kernels, n_kernels,
aggregator_type='sum', aggregator_type="sum",
residual=False, residual=False,
bias=True, bias=True,
allow_zero_in_degree=False): allow_zero_in_degree=False,
):
super(GMMConv, self).__init__() super(GMMConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
...@@ -123,38 +127,44 @@ class GMMConv(nn.Block): ...@@ -123,38 +127,44 @@ class GMMConv(nn.Block):
self._dim = dim self._dim = dim
self._n_kernels = n_kernels self._n_kernels = n_kernels
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
if aggregator_type == 'sum': if aggregator_type == "sum":
self._reducer = fn.sum self._reducer = fn.sum
elif aggregator_type == 'mean': elif aggregator_type == "mean":
self._reducer = fn.mean self._reducer = fn.mean
elif aggregator_type == 'max': elif aggregator_type == "max":
self._reducer = fn.max self._reducer = fn.max
else: else:
raise KeyError("Aggregator type {} not recognized.".format(aggregator_type)) raise KeyError(
"Aggregator type {} not recognized.".format(aggregator_type)
)
with self.name_scope(): with self.name_scope():
self.mu = self.params.get('mu', self.mu = self.params.get(
shape=(n_kernels, dim), "mu", shape=(n_kernels, dim), init=mx.init.Normal(0.1)
init=mx.init.Normal(0.1)) )
self.inv_sigma = self.params.get('inv_sigma', self.inv_sigma = self.params.get(
shape=(n_kernels, dim), "inv_sigma", shape=(n_kernels, dim), init=mx.init.Constant(1)
init=mx.init.Constant(1)) )
self.fc = nn.Dense(n_kernels * out_feats, self.fc = nn.Dense(
n_kernels * out_feats,
in_units=self._in_src_feats, in_units=self._in_src_feats,
use_bias=False, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
if residual: if residual:
if self._in_dst_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Dense(out_feats, in_units=self._in_dst_feats, use_bias=False) self.res_fc = nn.Dense(
out_feats, in_units=self._in_dst_feats, use_bias=False
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
self.res_fc = None self.res_fc = None
if bias: if bias:
self.bias = self.params.get('bias', self.bias = self.params.get(
shape=(out_feats,), "bias", shape=(out_feats,), init=mx.init.Zero()
init=mx.init.Zero()) )
else: else:
self.bias = None self.bias = None
...@@ -208,32 +218,44 @@ class GMMConv(nn.Block): ...@@ -208,32 +218,44 @@ class GMMConv(nn.Block):
""" """
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if graph.in_degrees().min() == 0: if graph.in_degrees().min() == 0:
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
with graph.local_scope(): with graph.local_scope():
graph.srcdata['h'] = self.fc(feat_src).reshape( graph.srcdata["h"] = self.fc(feat_src).reshape(
-1, self._n_kernels, self._out_feats) -1, self._n_kernels, self._out_feats
)
E = graph.number_of_edges() E = graph.number_of_edges()
# compute gaussian weight # compute gaussian weight
gaussian = -0.5 * ((pseudo.reshape(E, 1, self._dim) - gaussian = -0.5 * (
self.mu.data(feat_src.context) (
.reshape(1, self._n_kernels, self._dim)) ** 2) pseudo.reshape(E, 1, self._dim)
gaussian = gaussian *\ - self.mu.data(feat_src.context).reshape(
(self.inv_sigma.data(feat_src.context) 1, self._n_kernels, self._dim
.reshape(1, self._n_kernels, self._dim) ** 2) )
)
** 2
)
gaussian = gaussian * (
self.inv_sigma.data(feat_src.context).reshape(
1, self._n_kernels, self._dim
)
** 2
)
gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1) gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1)
graph.edata['w'] = gaussian graph.edata["w"] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h')) graph.update_all(fn.u_mul_e("h", "w", "m"), self._reducer("m", "h"))
rst = graph.dstdata['h'].sum(1) rst = graph.dstdata["h"].sum(1)
# residual connection # residual connection
if self.res_fc is not None: if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst) rst = rst + self.res_fc(feat_dst)
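As a reading aid (not introduced by this commit), the block above computes, for each edge pseudo-coordinate u and kernel k, a diagonal-covariance Gaussian weight

\[
w_k(u) \;=\; \exp\!\Bigl(-\tfrac{1}{2}\sum_{d=1}^{D}\frac{(u_d-\mu_{k,d})^{2}}{\sigma_{k,d}^{2}}\Bigr),
\]

with inv_sigma storing 1/sigma, which is why the squared difference is multiplied by inv_sigma ** 2 before the sum and the exponential.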
...
@@ -9,6 +9,7 @@ from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair


class GraphConv(gluon.Block):
    r"""Graph convolutional layer from `Semi-Supervised Classification with Graph Convolutional
    Networks <https://arxiv.org/abs/1609.02907>`__
@@ -133,18 +134,23 @@ class GraphConv(gluon.Block):
     [ 0.26967263  0.308129  ]]
    <NDArray 4x2 @cpu(0)>
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        norm="both",
        weight=True,
        bias=True,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super(GraphConv, self).__init__()
        if norm not in ("none", "both", "right", "left"):
            raise DGLError(
                'Invalid norm value. Must be either "none", "both", "right" or "left".'
                ' But got "{}".'.format(norm)
            )
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._norm = norm
@@ -152,14 +158,18 @@ class GraphConv(gluon.Block):
        with self.name_scope():
            if weight:
                self.weight = self.params.get(
                    "weight",
                    shape=(in_feats, out_feats),
                    init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
                )
            else:
                self.weight = None

            if bias:
                self.bias = self.params.get(
                    "bias", shape=(out_feats,), init=mx.init.Zero()
                )
            else:
                self.bias = None
@@ -225,21 +235,27 @@ class GraphConv(gluon.Block):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if graph.in_degrees().min() == 0:
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            feat_src, feat_dst = expand_as_pair(feat, graph)
            if self._norm in ["both", "left"]:
                degs = (
                    graph.out_degrees()
                    .as_in_context(feat_dst.context)
                    .astype("float32")
                )
                degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
                if self._norm == "both":
                    norm = mx.nd.power(degs, -0.5)
                else:
                    norm = 1.0 / degs
@@ -247,12 +263,13 @@ class GraphConv(gluon.Block):
                norm = norm.reshape(shp)
                feat_src = feat_src * norm

            if weight is not None:
                if self.weight is not None:
                    raise DGLError(
                        "External weight is provided while at the same time the"
                        " module has defined its own weight parameter. Please"
                        " create the module with flag weight=False."
                    )
            else:
                weight = self.weight.data(feat_src.context)
@@ -260,23 +277,29 @@ class GraphConv(gluon.Block):
                # mult W first to reduce the feature size for aggregation.
                if weight is not None:
                    feat_src = mx.nd.dot(feat_src, weight)
                graph.srcdata["h"] = feat_src
                graph.update_all(
                    fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
                )
                rst = graph.dstdata.pop("h")
            else:
                # aggregate first then mult W
                graph.srcdata["h"] = feat_src
                graph.update_all(
                    fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
                )
                rst = graph.dstdata.pop("h")
                if weight is not None:
                    rst = mx.nd.dot(rst, weight)

            if self._norm in ["both", "right"]:
                degs = (
                    graph.in_degrees()
                    .as_in_context(feat_dst.context)
                    .astype("float32")
                )
                degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
                if self._norm == "both":
                    norm = mx.nd.power(degs, -0.5)
                else:
                    norm = 1.0 / degs
@@ -293,9 +316,9 @@ class GraphConv(gluon.Block):
        return rst

    def __repr__(self):
        summary = "GraphConv("
        summary += "in={:d}, out={:d}, normalization={}, activation={}".format(
            self._in_feats, self._out_feats, self._norm, self._activation
        )
        summary += ")"
        return summary
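For orientation (not part of the commit), with norm="both" the two normalization passes in the forward above implement

\[
h_i^{(l+1)} \;=\; \sum_{j\in\mathcal{N}(i)} \frac{1}{\sqrt{d_j\,d_i}}\, h_j^{(l)} W^{(l)},
\]

up to the optional bias and activation: the out-degree factor d_j^{-1/2} is applied to source features before aggregation and the in-degree factor d_i^{-1/2} to the result afterwards, while "left" and "right" instead apply a single 1/d factor on the corresponding side.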
@@ -89,24 +89,29 @@ class NNConv(nn.Block):
     [ 0.24425688  0.3238042 ]]
    <NDArray 4x2 @cpu(0)>
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        edge_func,
        aggregator_type,
        residual=False,
        bias=True,
    ):
        super(NNConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        if aggregator_type == "sum":
            self.reducer = fn.sum
        elif aggregator_type == "mean":
            self.reducer = fn.mean
        elif aggregator_type == "max":
            self.reducer = fn.max
        else:
            raise KeyError(
                "Aggregator type {} not recognized: ".format(aggregator_type)
            )
        self._aggre_type = aggregator_type
        with self.name_scope():
@@ -114,17 +119,20 @@ class NNConv(nn.Block):
            if residual:
                if self._in_dst_feats != out_feats:
                    self.res_fc = nn.Dense(
                        out_feats,
                        in_units=self._in_dst_feats,
                        use_bias=False,
                        weight_initializer=mx.init.Xavier(),
                    )
                else:
                    self.res_fc = Identity()
            else:
                self.res_fc = None
            if bias:
                self.bias = self.params.get(
                    "bias", shape=(out_feats,), init=mx.init.Zero()
                )
            else:
                self.bias = None
@@ -153,12 +161,16 @@ class NNConv(nn.Block):
            feat_src, feat_dst = expand_as_pair(feat, graph)

            # (n, d_in, 1)
            graph.srcdata["h"] = feat_src.expand_dims(-1)
            # (n, d_in, d_out)
            graph.edata["w"] = self.edge_nn(efeat).reshape(
                -1, self._in_src_feats, self._out_feats
            )
            # (n, d_in, d_out)
            graph.update_all(
                fn.u_mul_e("h", "w", "m"), self.reducer("m", "neigh")
            )
            rst = graph.dstdata.pop("neigh").sum(axis=1)  # (n, d_out)
            # residual connection
            if self.res_fc is not None:
                rst = rst + self.res_fc(feat_dst)
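For reference (not part of the commit), the message-passing block above realizes the edge-conditioned update

\[
r_i \;=\; \operatorname{aggr}_{j\in\mathcal{N}(i)}\ \Theta(e_{ji})^{\top} h_j,
\]

where \(\Theta\) is the user-supplied edge_func producing a d_in x d_out matrix per edge: the broadcasted u_mul_e followed by the reducer and the sum over axis 1 is exactly that matrix-vector product, and the optional residual adds res_fc(feat_dst).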
...
"""MXNet Module for GraphSAGE layer""" """MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape from ....utils import check_eq_shape, expand_as_pair
class SAGEConv(nn.Block): class SAGEConv(nn.Block):
r"""GraphSAGE layer from `Inductive Representation Learning on r"""GraphSAGE layer from `Inductive Representation Learning on
...@@ -89,20 +91,25 @@ class SAGEConv(nn.Block): ...@@ -89,20 +91,25 @@ class SAGEConv(nn.Block):
[-1.0509381 2.2239418 ]] [-1.0509381 2.2239418 ]]
<NDArray 4x2 @cpu(0)> <NDArray 4x2 @cpu(0)>
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
aggregator_type='mean', aggregator_type="mean",
feat_drop=0., feat_drop=0.0,
bias=True, bias=True,
norm=None, norm=None,
activation=None): activation=None,
):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'} valid_aggre_types = {"mean", "gcn", "pool", "lstm"}
if aggregator_type not in valid_aggre_types: if aggregator_type not in valid_aggre_types:
raise DGLError( raise DGLError(
'Invalid aggregator_type. Must be one of {}. ' "Invalid aggregator_type. Must be one of {}. "
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type) "But got {!r} instead.".format(
valid_aggre_types, aggregator_type
)
) )
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
...@@ -112,19 +119,28 @@ class SAGEConv(nn.Block): ...@@ -112,19 +119,28 @@ class SAGEConv(nn.Block):
self.norm = norm self.norm = norm
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation self.activation = activation
if aggregator_type == 'pool': if aggregator_type == "pool":
self.fc_pool = nn.Dense(self._in_src_feats, use_bias=bias, self.fc_pool = nn.Dense(
self._in_src_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats) in_units=self._in_src_feats,
if aggregator_type == 'lstm': )
if aggregator_type == "lstm":
raise NotImplementedError raise NotImplementedError
if aggregator_type != 'gcn': if aggregator_type != "gcn":
self.fc_self = nn.Dense(out_feats, use_bias=bias, self.fc_self = nn.Dense(
out_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats) in_units=self._in_dst_feats,
self.fc_neigh = nn.Dense(out_feats, use_bias=bias, )
self.fc_neigh = nn.Dense(
out_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats) in_units=self._in_src_feats,
)
def forward(self, graph, feat): def forward(self, graph, feat):
r"""Compute GraphSAGE layer. r"""Compute GraphSAGE layer.
...@@ -153,39 +169,47 @@ class SAGEConv(nn.Block): ...@@ -153,39 +169,47 @@ class SAGEConv(nn.Block):
else: else:
feat_src = feat_dst = self.feat_drop(feat) feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_self = feat_dst h_self = feat_dst
# Handle the case of graphs without edges # Handle the case of graphs without edges
if graph.number_of_edges() == 0: if graph.number_of_edges() == 0:
dst_neigh = mx.nd.zeros((graph.number_of_dst_nodes(), self._in_src_feats)) dst_neigh = mx.nd.zeros(
(graph.number_of_dst_nodes(), self._in_src_feats)
)
dst_neigh = dst_neigh.as_in_context(feat_dst.context) dst_neigh = dst_neigh.as_in_context(feat_dst.context)
graph.dstdata['neigh'] = dst_neigh graph.dstdata["neigh"] = dst_neigh
if self._aggre_type == 'mean': if self._aggre_type == "mean":
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_u("h", "m"), fn.mean("m", "neigh"))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata["neigh"]
elif self._aggre_type == 'gcn': elif self._aggre_type == "gcn":
check_eq_shape(feat) check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.dstdata["h"] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "neigh"))
# divide in degrees # divide in degrees
degs = graph.in_degrees().astype(feat_dst.dtype) degs = graph.in_degrees().astype(feat_dst.dtype)
degs = degs.as_in_context(feat_dst.context) degs = degs.as_in_context(feat_dst.context)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.expand_dims(-1) + 1) h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
elif self._aggre_type == 'pool': degs.expand_dims(-1) + 1
graph.srcdata['h'] = nd.relu(self.fc_pool(feat_src)) )
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) elif self._aggre_type == "pool":
h_neigh = graph.dstdata['neigh'] graph.srcdata["h"] = nd.relu(self.fc_pool(feat_src))
elif self._aggre_type == 'lstm': graph.update_all(fn.copy_u("h", "m"), fn.max("m", "neigh"))
h_neigh = graph.dstdata["neigh"]
elif self._aggre_type == "lstm":
raise NotImplementedError raise NotImplementedError
else: else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) raise KeyError(
"Aggregator type {} not recognized.".format(
self._aggre_type
)
)
if self._aggre_type == 'gcn': if self._aggre_type == "gcn":
rst = self.fc_neigh(h_neigh) rst = self.fc_neigh(h_neigh)
else: else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
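For reference (not part of the commit), the aggregator branches above compute

\[
h_{\mathcal{N}(i)} = \operatorname{mean}_{j\in\mathcal{N}(i)} h_j \ \ (\text{or } \max \text{ for "pool"}),
\qquad
r_i = W_{\mathrm{self}}\, h_i + W_{\mathrm{neigh}}\, h_{\mathcal{N}(i)},
\]

while the "gcn" branch folds the self term into the average, \(r_i = W_{\mathrm{neigh}}\,(h_i + \sum_{j} h_j)/(d_i + 1)\), which is why it skips fc_self.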
...
@@ -60,12 +60,8 @@ class TAGConv(gluon.Block):
     [ 0.32964635 -0.7669234 ]]
    <NDArray 6x2 @cpu(0)>
    """

    def __init__(self, in_feats, out_feats, k=2, bias=True, activation=None):
        super(TAGConv, self).__init__()
        self.out_feats = out_feats
        self.k = k
@@ -74,11 +70,14 @@ class TAGConv(gluon.Block):
        self.in_feats = in_feats

        self.lin = self.params.get(
            "weight",
            shape=(self.in_feats * (self.k + 1), self.out_feats),
            init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
        )
        if self.bias:
            self.h_bias = self.params.get(
                "bias", shape=(out_feats,), init=mx.init.Zero()
            )

    def forward(self, graph, feat):
        r"""
@@ -102,21 +101,24 @@ class TAGConv(gluon.Block):
            is size of output feature.
        """
        with graph.local_scope():
            assert graph.is_homogeneous, "Graph is not homogeneous"
            degs = graph.in_degrees().astype("float32")
            norm = mx.nd.power(
                mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5
            )
            shp = norm.shape + (1,) * (feat.ndim - 1)
            norm = norm.reshape(shp).as_in_context(feat.context)

            rst = feat
            for _ in range(self.k):
                rst = rst * norm
                graph.ndata["h"] = rst
                graph.update_all(
                    fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
                )
                rst = graph.ndata["h"]
                rst = rst * norm
                feat = mx.nd.concat(feat, rst, dim=-1)
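For reference (not part of the commit), each pass through the loop above multiplies by the symmetrically normalized adjacency \(\hat{A} = D^{-1/2} A D^{-1/2}\), so after k hops the concatenated features are

\[
Z \;=\; \bigl[\,X,\ \hat{A}X,\ \hat{A}^{2}X,\ \dots,\ \hat{A}^{k}X\,\bigr],
\]

which the (in_feats * (k + 1)) x out_feats parameter lin then projects; that projection happens outside the hunk shown here.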
...
@@ -7,7 +7,13 @@ from .glob import *
from .softmax import *
from .factory import *
from .hetero import *
from .sparse_emb import NodeEmbedding
from .utils import (
    JumpingKnowledge,
    LabelPropagation,
    LaplacianPosEnc,
    Sequential,
    WeightBasis,
)
from .network_emb import *
from .graph_transformer import *
@@ -4,8 +4,7 @@ import torch as th
import torch.nn.functional as F
from torch import nn

from .... import broadcast_nodes, function as fn
from ....base import dgl_warning
@@ -101,10 +100,9 @@ class ChebConv(nn.Module):
            return graph.ndata.pop("h") * D_invsqrt

        with graph.local_scope():
            D_invsqrt = th.pow(
                graph.in_degrees().to(feat).clamp(min=1), -0.5
            ).unsqueeze(-1)
            if lambda_max is None:
                dgl_warning(
...
@@ -5,7 +5,7 @@ from functools import partial
import torch
import torch.nn as nn

from .pnaconv import AGGREGATORS, PNAConv, PNAConvTower, SCALERS


def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
...
@@ -3,9 +3,9 @@
from torch import nn

from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ...functional import edge_softmax


class DotGatConv(nn.Module):
@@ -118,11 +118,10 @@ class DotGatConv(nn.Module):
            [-0.5945, -0.4801],
            [ 0.1594,  0.3825]]], grad_fn=<BinaryReduceBackward>)
    """

    def __init__(
        self, in_feats, out_feats, num_heads, allow_zero_in_degree=False
    ):
        super(DotGatConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
@@ -130,10 +129,22 @@ class DotGatConv(nn.Module):
        self._num_heads = num_heads

        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats,
                self._out_feats * self._num_heads,
                bias=False,
            )
            self.fc_dst = nn.Linear(
                self._in_dst_feats,
                self._out_feats * self._num_heads,
                bias=False,
            )
        else:
            self.fc = nn.Linear(
                self._in_src_feats,
                self._out_feats * self._num_heads,
                bias=False,
            )

    def forward(self, graph, feat, get_attention=False):
        r"""
@@ -175,45 +186,57 @@ class DotGatConv(nn.Module):
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            # check if feat is a tuple
            if isinstance(feat, tuple):
                h_src = feat[0]
                h_dst = feat[1]
                feat_src = self.fc_src(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                feat_dst = self.fc_dst(h_dst).view(
                    -1, self._num_heads, self._out_feats
                )
            else:
                h_src = feat
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]

            # Assign features to nodes
            graph.srcdata.update({"ft": feat_src})
            graph.dstdata.update({"ft": feat_dst})

            # Step 1. dot product
            graph.apply_edges(fn.u_dot_v("ft", "ft", "a"))

            # Step 2. edge softmax to compute attention scores
            graph.edata["sa"] = edge_softmax(
                graph, graph.edata["a"] / self._out_feats**0.5
            )

            # Step 3. Broadcast softmax value to each edge, and aggregate dst node
            graph.update_all(
                fn.u_mul_e("ft", "sa", "attn"), fn.sum("attn", "agg_u")
            )

            # output results to the destination nodes
            rst = graph.dstdata["agg_u"]

            if get_attention:
                return rst, graph.edata["sa"]
            else:
                return rst
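For reference (not part of the commit), Steps 1-3 above implement scaled dot-product attention on the graph:

\[
\alpha_{ij} \;=\; \operatorname{softmax}_{j\in\mathcal{N}(i)}\!\Bigl(\frac{(W h_i)^{\top} W h_j}{\sqrt{d}}\Bigr),
\qquad
r_i \;=\; \sum_{j\in\mathcal{N}(i)} \alpha_{ij}\, W h_j,
\]

with d = out_feats per head and the normalization performed by edge_softmax over each destination node's incoming edges.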
@@ -2,8 +2,9 @@
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn

from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
@@ -92,11 +93,10 @@ class EdgeConv(nn.Module):
            [ 0.2101,  1.3466],
            [ 0.2342, -0.9868]], grad_fn=<CopyReduceBackward>)
    """

    def __init__(
        self, in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False
    ):
        super(EdgeConv, self).__init__()
        self.batch_norm = batch_norm
        self._allow_zero_in_degree = allow_zero_in_degree
@@ -155,26 +155,28 @@ class EdgeConv(nn.Module):
        with g.local_scope():
            if not self._allow_zero_in_degree:
                if (g.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            h_src, h_dst = expand_as_pair(feat, g)
            g.srcdata["x"] = h_src
            g.dstdata["x"] = h_dst
            g.apply_edges(fn.v_sub_u("x", "x", "theta"))
            g.edata["theta"] = self.theta(g.edata["theta"])
            g.dstdata["phi"] = self.phi(g.dstdata["x"])
            if not self.batch_norm:
                g.update_all(fn.e_add_v("theta", "phi", "e"), fn.max("e", "x"))
            else:
                g.apply_edges(fn.e_add_v("theta", "phi", "e"))
                # Although the official implementation includes a per-edge
                # batch norm within EdgeConv, I choose to replace it with a
                # global batch norm for a number of reasons:
@@ -194,6 +196,6 @@ class EdgeConv(nn.Module):
                # In this case, the learned statistics of each position
                # by batch norm is not as meaningful as those learned from
                # images.
                g.edata["e"] = self.bn(g.edata["e"])
                g.update_all(fn.copy_e("e", "e"), fn.max("e", "x"))
            return g.dstdata["x"]
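For reference (not part of the commit), with an edge written as (u -> v) the message construction above corresponds to

\[
x_v' \;=\; \max_{(u,v)\in E}\bigl(\Theta\,(x_v - x_u) + \Phi\, x_v\bigr),
\]

reading fn.v_sub_u as destination-minus-source, with \(\Theta\) = self.theta and \(\Phi\) = self.phi; in the batch-norm branch the same per-edge term is passed through self.bn before the max.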
@@ -5,9 +5,10 @@ from torch import nn
from torch.nn import init

from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ...functional import edge_softmax


# pylint: enable=W0235
class EGATConv(nn.Module):
@@ -94,47 +95,63 @@ class EGATConv(nn.Module):
    >>> new_node_feats.shape, new_edge_feats.shape, attentions.shape
    (torch.Size([4, 3, 10]), torch.Size([5, 3, 5]), torch.Size([5, 3, 1]))
    """

    def __init__(
        self,
        in_node_feats,
        in_edge_feats,
        out_node_feats,
        out_edge_feats,
        num_heads,
        bias=True,
    ):
        super().__init__()
        self._num_heads = num_heads
        self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(
            in_node_feats
        )
        self._out_node_feats = out_node_feats
        self._out_edge_feats = out_edge_feats
        if isinstance(in_node_feats, tuple):
            self.fc_node_src = nn.Linear(
                self._in_src_node_feats, out_node_feats * num_heads, bias=False
            )
            self.fc_ni = nn.Linear(
                self._in_src_node_feats, out_edge_feats * num_heads, bias=False
            )
            self.fc_nj = nn.Linear(
                self._in_dst_node_feats, out_edge_feats * num_heads, bias=False
            )
        else:
            self.fc_node_src = nn.Linear(
                self._in_src_node_feats, out_node_feats * num_heads, bias=False
            )
            self.fc_ni = nn.Linear(
                self._in_src_node_feats, out_edge_feats * num_heads, bias=False
            )
            self.fc_nj = nn.Linear(
                self._in_src_node_feats, out_edge_feats * num_heads, bias=False
            )

        self.fc_fij = nn.Linear(
            in_edge_feats, out_edge_feats * num_heads, bias=False
        )
        self.attn = nn.Parameter(
            th.FloatTensor(size=(1, num_heads, out_edge_feats))
        )
        if bias:
            self.bias = nn.Parameter(
                th.FloatTensor(size=(num_heads * out_edge_feats,))
            )
        else:
            self.register_buffer("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reinitialize learnable parameters.
        """
        gain = init.calculate_gain("relu")
        init.xavier_normal_(self.fc_node_src.weight, gain=gain)
        init.xavier_normal_(self.fc_ni.weight, gain=gain)
        init.xavier_normal_(self.fc_fij.weight, gain=gain)
@@ -183,13 +200,15 @@ class EGATConv(nn.Module):
        with graph.local_scope():
            if (graph.in_degrees() == 0).any():
                raise DGLError(
                    "There are 0-in-degree nodes in the graph, "
                    "output for those nodes will be invalid. "
                    "This is harmful for some applications, "
                    "causing silent performance regression. "
                    "Adding self-loop on the input graph by "
                    "calling `g = dgl.add_self_loop(g)` will resolve "
                    "the issue."
                )

            # calc edge attention
            # same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats
@@ -203,27 +222,31 @@ class EGATConv(nn.Module):
            f_nj = self.fc_nj(nfeats_dst)
            f_fij = self.fc_fij(efeats)

            graph.srcdata.update({"f_ni": f_ni})
            graph.dstdata.update({"f_nj": f_nj})
            # add ni, nj factors
            graph.apply_edges(fn.u_add_v("f_ni", "f_nj", "f_tmp"))
            # add fij to node factor
            f_out = graph.edata.pop("f_tmp") + f_fij
            if self.bias is not None:
                f_out = f_out + self.bias
            f_out = nn.functional.leaky_relu(f_out)
            f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
            # compute attention factor
            e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
            graph.edata["a"] = edge_softmax(graph, e)
            graph.srcdata["h_out"] = self.fc_node_src(nfeats_src).view(
                -1, self._num_heads, self._out_node_feats
            )
            # calc weighted sum
            graph.update_all(
                fn.u_mul_e("h_out", "a", "m"), fn.sum("m", "h_out")
            )

            h_out = graph.dstdata["h_out"].view(
                -1, self._num_heads, self._out_node_feats
            )
            if get_attention:
                return h_out, f_out, graph.edata.pop("a")
            else:
                return h_out, f_out
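For reference (not part of the commit), writing an edge as (u -> v), the edge and node updates above amount to

\[
f'_{uv} = \operatorname{LeakyReLU}\bigl(A h_u + B h_v + C f_{uv} + b\bigr),
\qquad
\alpha_{uv} = \operatorname{softmax}_{u\in\mathcal{N}(v)}\bigl(a^{\top} f'_{uv}\bigr),
\qquad
h'_v = \sum_{u\in\mathcal{N}(v)} \alpha_{uv}\, W h_u,
\]

where A, B, C and W stand for fc_ni, fc_nj, fc_fij and fc_node_src, applied per attention head.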
...@@ -4,10 +4,11 @@ import torch as th ...@@ -4,10 +4,11 @@ import torch as th
from torch import nn from torch import nn
from .... import function as fn from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair from ....utils import expand_as_pair
from ...functional import edge_softmax
from ..utils import Identity
# pylint: enable=W0235 # pylint: enable=W0235
class GATConv(nn.Module): class GATConv(nn.Module):
...@@ -130,17 +131,20 @@ class GATConv(nn.Module): ...@@ -130,17 +131,20 @@ class GATConv(nn.Module):
[-0.5945, -0.4801], [-0.5945, -0.4801],
[ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>) [ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
num_heads, num_heads,
feat_drop=0., feat_drop=0.0,
attn_drop=0., attn_drop=0.0,
negative_slope=0.2, negative_slope=0.2,
residual=False, residual=False,
activation=None, activation=None,
allow_zero_in_degree=False, allow_zero_in_degree=False,
bias=True): bias=True,
):
super(GATConv, self).__init__() super(GATConv, self).__init__()
self._num_heads = num_heads self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
...@@ -148,29 +152,39 @@ class GATConv(nn.Module): ...@@ -148,29 +152,39 @@ class GATConv(nn.Module):
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
if isinstance(in_feats, tuple): if isinstance(in_feats, tuple):
self.fc_src = nn.Linear( self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False) self._in_src_feats, out_feats * num_heads, bias=False
)
self.fc_dst = nn.Linear( self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=False) self._in_dst_feats, out_feats * num_heads, bias=False
)
else: else:
self.fc = nn.Linear( self.fc = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False) self._in_src_feats, out_feats * num_heads, bias=False
self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) )
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) self.attn_l = nn.Parameter(
th.FloatTensor(size=(1, num_heads, out_feats))
)
self.attn_r = nn.Parameter(
th.FloatTensor(size=(1, num_heads, out_feats))
)
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if bias: if bias:
self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_feats,))) self.bias = nn.Parameter(
th.FloatTensor(size=(num_heads * out_feats,))
)
else: else:
self.register_buffer('bias', None) self.register_buffer("bias", None)
if residual: if residual:
if self._in_dst_feats != out_feats * num_heads: if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear( self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=False) self._in_dst_feats, num_heads * out_feats, bias=False
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
self.register_buffer('res_fc', None) self.register_buffer("res_fc", None)
self.reset_parameters() self.reset_parameters()
self.activation = activation self.activation = activation
...@@ -186,8 +200,8 @@ class GATConv(nn.Module): ...@@ -186,8 +200,8 @@ class GATConv(nn.Module):
The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization. The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The attention weights are initialized using the Xavier method. The attention weights are initialized using the Xavier method.
""" """
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
if hasattr(self, 'fc'): if hasattr(self, "fc"):
nn.init.xavier_normal_(self.fc.weight, gain=gain) nn.init.xavier_normal_(self.fc.weight, gain=gain)
else: else:
nn.init.xavier_normal_(self.fc_src.weight, gain=gain) nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
...@@ -251,40 +265,49 @@ class GATConv(nn.Module): ...@@ -251,40 +265,49 @@ class GATConv(nn.Module):
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any(): if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple): if isinstance(feat, tuple):
src_prefix_shape = feat[0].shape[:-1] src_prefix_shape = feat[0].shape[:-1]
dst_prefix_shape = feat[1].shape[:-1] dst_prefix_shape = feat[1].shape[:-1]
h_src = self.feat_drop(feat[0]) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1]) h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'): if not hasattr(self, "fc_src"):
feat_src = self.fc(h_src).view( feat_src = self.fc(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats) *src_prefix_shape, self._num_heads, self._out_feats
)
feat_dst = self.fc(h_dst).view( feat_dst = self.fc(h_dst).view(
*dst_prefix_shape, self._num_heads, self._out_feats) *dst_prefix_shape, self._num_heads, self._out_feats
)
else: else:
feat_src = self.fc_src(h_src).view( feat_src = self.fc_src(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats) *src_prefix_shape, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view( feat_dst = self.fc_dst(h_dst).view(
*dst_prefix_shape, self._num_heads, self._out_feats) *dst_prefix_shape, self._num_heads, self._out_feats
)
else: else:
src_prefix_shape = dst_prefix_shape = feat.shape[:-1] src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
h_src = h_dst = self.feat_drop(feat) h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view( feat_src = feat_dst = self.fc(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats) *src_prefix_shape, self._num_heads, self._out_feats
)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()] h_dst = h_dst[: graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] dst_prefix_shape = (
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
...@@ -297,31 +320,35 @@ class GATConv(nn.Module): ...@@ -297,31 +320,35 @@ class GATConv(nn.Module):
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({'ft': feat_src, 'el': el}) graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({'er': er}) graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax # compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing # message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
fn.sum('m', 'ft')) rst = graph.dstdata["ft"]
rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
# Use -1 rather than self._num_heads to handle broadcasting # Use -1 rather than self._num_heads to handle broadcasting
resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats) resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval rst = rst + resval
# bias # bias
if self.bias is not None: if self.bias is not None:
rst = rst + self.bias.view( rst = rst + self.bias.view(
*((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats) *((1,) * len(dst_prefix_shape)),
self._num_heads,
self._out_feats
)
# activation # activation
if self.activation: if self.activation:
rst = self.activation(rst) rst = self.activation(rst)
if get_attention: if get_attention:
return rst, graph.edata['a'] return rst, graph.edata["a"]
else: else:
return rst return rst
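The NOTE in ``GATConv.forward`` above refers to splitting the attention vector: writing :math:`a = [a_l \,\|\, a_r]`, the "concatenation then projection" score of the GAT paper decomposes exactly as

.. math::
    a^\top [W h_j \,\|\, W h_i] = a_l^\top W h_j + a_r^\top W h_i,

which is why the code only materializes the per-node scalars ``el`` :math:`= a_l^\top W h_j` (on source nodes) and ``er`` :math:`= a_r^\top W h_i` (on destination nodes), one per head, and combines them on the edges with ``fn.u_add_v`` before the LeakyReLU and the edge softmax.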
...@@ -85,24 +85,28 @@ class GINConv(nn.Module): ...@@ -85,24 +85,28 @@ class GINConv(nn.Module):
[2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000, [2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000,
0.0000]], grad_fn=<ReluBackward0>) 0.0000]], grad_fn=<ReluBackward0>)
""" """
def __init__(self,
def __init__(
self,
apply_func=None, apply_func=None,
aggregator_type='sum', aggregator_type="sum",
init_eps=0, init_eps=0,
learn_eps=False, learn_eps=False,
activation=None): activation=None,
):
super(GINConv, self).__init__() super(GINConv, self).__init__()
self.apply_func = apply_func self.apply_func = apply_func
self._aggregator_type = aggregator_type self._aggregator_type = aggregator_type
self.activation = activation self.activation = activation
if aggregator_type not in ('sum', 'max', 'mean'): if aggregator_type not in ("sum", "max", "mean"):
raise KeyError( raise KeyError(
'Aggregator type {} not recognized.'.format(aggregator_type)) "Aggregator type {} not recognized.".format(aggregator_type)
)
# to specify whether eps is trainable or not. # to specify whether eps is trainable or not.
if learn_eps: if learn_eps:
self.eps = th.nn.Parameter(th.FloatTensor([init_eps])) self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
else: else:
self.register_buffer('eps', th.FloatTensor([init_eps])) self.register_buffer("eps", th.FloatTensor([init_eps]))
def forward(self, graph, feat, edge_weight=None): def forward(self, graph, feat, edge_weight=None):
r""" r"""
...@@ -136,16 +140,16 @@ class GINConv(nn.Module): ...@@ -136,16 +140,16 @@ class GINConv(nn.Module):
""" """
_reducer = getattr(fn, self._aggregator_type) _reducer = getattr(fn, self._aggregator_type)
with graph.local_scope(): with graph.local_scope():
aggregate_fn = fn.copy_u('h', 'm') aggregate_fn = fn.copy_u("h", "m")
if edge_weight is not None: if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges() assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight graph.edata["_edge_weight"] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm') aggregate_fn = fn.u_mul_e("h", "_edge_weight", "m")
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.update_all(aggregate_fn, _reducer('m', 'neigh')) graph.update_all(aggregate_fn, _reducer("m", "neigh"))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh'] rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
if self.apply_func is not None: if self.apply_func is not None:
rst = self.apply_func(rst) rst = self.apply_func(rst)
# activation # activation
......
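Putting the GINConv pieces together, the layer computes :math:`h_i' = f_\Theta\big((1 + \epsilon)\, h_i + \mathrm{aggregate}_{j \in \mathcal{N}(i)}\, h_j\big)`, with messages scaled by ``edge_weight`` when one is given and :math:`f_\Theta` being ``apply_func``. A minimal usage sketch, assuming only the constructor shown above (the MLP and the sizes are made up for illustration):

import torch as th
import torch.nn as nn
import dgl
from dgl.nn.pytorch import GINConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))            # a 4-node directed cycle
feat = th.ones(4, 10)
mlp = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 16))
conv = GINConv(mlp, aggregator_type="sum", learn_eps=True)
out = conv(g, feat)                                     # -> shape (4, 16)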
...@@ -6,8 +6,8 @@ from torch.nn import init ...@@ -6,8 +6,8 @@ from torch.nn import init
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair from ....utils import expand_as_pair
from ..utils import Identity
class GMMConv(nn.Module): class GMMConv(nn.Module):
...@@ -103,45 +103,54 @@ class GMMConv(nn.Module): ...@@ -103,45 +103,54 @@ class GMMConv(nn.Module):
[-0.1377, -0.1943], [-0.1377, -0.1943],
[-0.1107, -0.1559]], grad_fn=<AddBackward0>) [-0.1107, -0.1559]], grad_fn=<AddBackward0>)
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
dim, dim,
n_kernels, n_kernels,
aggregator_type='sum', aggregator_type="sum",
residual=False, residual=False,
bias=True, bias=True,
allow_zero_in_degree=False): allow_zero_in_degree=False,
):
super(GMMConv, self).__init__() super(GMMConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self._dim = dim self._dim = dim
self._n_kernels = n_kernels self._n_kernels = n_kernels
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
if aggregator_type == 'sum': if aggregator_type == "sum":
self._reducer = fn.sum self._reducer = fn.sum
elif aggregator_type == 'mean': elif aggregator_type == "mean":
self._reducer = fn.mean self._reducer = fn.mean
elif aggregator_type == 'max': elif aggregator_type == "max":
self._reducer = fn.max self._reducer = fn.max
else: else:
raise KeyError("Aggregator type {} not recognized.".format(aggregator_type)) raise KeyError(
"Aggregator type {} not recognized.".format(aggregator_type)
)
self.mu = nn.Parameter(th.Tensor(n_kernels, dim)) self.mu = nn.Parameter(th.Tensor(n_kernels, dim))
self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim)) self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim))
self.fc = nn.Linear(self._in_src_feats, n_kernels * out_feats, bias=False) self.fc = nn.Linear(
self._in_src_feats, n_kernels * out_feats, bias=False
)
if residual: if residual:
if self._in_dst_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(self._in_dst_feats, out_feats, bias=False) self.res_fc = nn.Linear(
self._in_dst_feats, out_feats, bias=False
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
self.register_buffer('res_fc', None) self.register_buffer("res_fc", None)
if bias: if bias:
self.bias = nn.Parameter(th.Tensor(out_feats)) self.bias = nn.Parameter(th.Tensor(out_feats))
else: else:
self.register_buffer('bias', None) self.register_buffer("bias", None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
...@@ -158,7 +167,7 @@ class GMMConv(nn.Module): ...@@ -158,7 +167,7 @@ class GMMConv(nn.Module):
The mu weight is initialized using a normal distribution and The mu weight is initialized using a normal distribution and
inv_sigma is initialized with a constant value of 1.0. inv_sigma is initialized with a constant value of 1.0.
""" """
gain = init.calculate_gain('relu') gain = init.calculate_gain("relu")
init.xavier_normal_(self.fc.weight, gain=gain) init.xavier_normal_(self.fc.weight, gain=gain)
if isinstance(self.res_fc, nn.Linear): if isinstance(self.res_fc, nn.Linear):
init.xavier_normal_(self.res_fc.weight, gain=gain) init.xavier_normal_(self.res_fc.weight, gain=gain)
...@@ -218,27 +227,38 @@ class GMMConv(nn.Module): ...@@ -218,27 +227,38 @@ class GMMConv(nn.Module):
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any(): if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = self.fc(feat_src).view(-1, self._n_kernels, self._out_feats) graph.srcdata["h"] = self.fc(feat_src).view(
-1, self._n_kernels, self._out_feats
)
E = graph.number_of_edges() E = graph.number_of_edges()
# compute gaussian weight # compute gaussian weight
gaussian = -0.5 * ((pseudo.view(E, 1, self._dim) - gaussian = -0.5 * (
self.mu.view(1, self._n_kernels, self._dim)) ** 2) (
gaussian = gaussian * (self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2) pseudo.view(E, 1, self._dim)
- self.mu.view(1, self._n_kernels, self._dim)
)
** 2
)
gaussian = gaussian * (
self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2
)
gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1) gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1)
graph.edata['w'] = gaussian graph.edata["w"] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h')) graph.update_all(fn.u_mul_e("h", "w", "m"), self._reducer("m", "h"))
rst = graph.dstdata['h'].sum(1) rst = graph.dstdata["h"].sum(1)
# residual connection # residual connection
if self.res_fc is not None: if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst) rst = rst + self.res_fc(feat_dst)
......
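For reference, the Gaussian weighting that the GMMConv hunk builds up over several lines is, per kernel :math:`k` and edge :math:`j \to i` with pseudo-coordinate :math:`u_{ij} \in \mathbb{R}^D` (:math:`\mu_k` is ``mu``, :math:`1/\sigma_k` is ``inv_sigma`` and :math:`W_k` the :math:`k`-th slice of ``fc``):

.. math::
    w_k(u_{ij}) = \exp\!\left(-\frac{1}{2} \sum_{d=1}^{D}
        \left(\frac{u_{ij,d} - \mu_{k,d}}{\sigma_{k,d}}\right)^{2}\right),
    \qquad
    h_i' = \sum_{k=1}^{K}\; \mathop{\mathrm{aggregate}}_{j \in \mathcal{N}(i)}\, w_k(u_{ij})\, W_k h_j,

with the residual connection and the optional bias added afterwards.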
...@@ -6,10 +6,11 @@ from torch.nn import init ...@@ -6,10 +6,11 @@ from torch.nn import init
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
from ....utils import expand_as_pair
from ....transforms import reverse
from ....convert import block_to_graph from ....convert import block_to_graph
from ....heterograph import DGLBlock from ....heterograph import DGLBlock
from ....transforms import reverse
from ....utils import expand_as_pair
class EdgeWeightNorm(nn.Module): class EdgeWeightNorm(nn.Module):
r"""This module normalizes positive scalar edge weights on a graph r"""This module normalizes positive scalar edge weights on a graph
...@@ -59,7 +60,8 @@ class EdgeWeightNorm(nn.Module): ...@@ -59,7 +60,8 @@ class EdgeWeightNorm(nn.Module):
[-1.3658, -0.8674], [-1.3658, -0.8674],
[-0.8323, -0.5286]], grad_fn=<AddBackward0>) [-0.8323, -0.5286]], grad_fn=<AddBackward0>)
""" """
def __init__(self, norm='both', eps=0.):
def __init__(self, norm="both", eps=0.0):
super(EdgeWeightNorm, self).__init__() super(EdgeWeightNorm, self).__init__()
self._norm = norm self._norm = norm
self._eps = eps self._eps = eps
...@@ -99,42 +101,57 @@ class EdgeWeightNorm(nn.Module): ...@@ -99,42 +101,57 @@ class EdgeWeightNorm(nn.Module):
if isinstance(graph, DGLBlock): if isinstance(graph, DGLBlock):
graph = block_to_graph(graph) graph = block_to_graph(graph)
if len(edge_weight.shape) > 1: if len(edge_weight.shape) > 1:
raise DGLError('Currently the normalization is only defined ' raise DGLError(
'on scalar edge weight. Please customize the ' "Currently the normalization is only defined "
'normalization for your high-dimensional weights.') "on scalar edge weight. Please customize the "
if self._norm == 'both' and th.any(edge_weight <= 0).item(): "normalization for your high-dimensional weights."
raise DGLError('Non-positive edge weight detected with `norm="both"`. ' )
'This leads to square root of zero or negative values.') if self._norm == "both" and th.any(edge_weight <= 0).item():
raise DGLError(
'Non-positive edge weight detected with `norm="both"`. '
"This leads to square root of zero or negative values."
)
dev = graph.device dev = graph.device
dtype = edge_weight.dtype dtype = edge_weight.dtype
graph.srcdata['_src_out_w'] = th.ones( graph.srcdata["_src_out_w"] = th.ones(
graph.number_of_src_nodes(), dtype=dtype, device=dev) graph.number_of_src_nodes(), dtype=dtype, device=dev
graph.dstdata['_dst_in_w'] = th.ones( )
graph.number_of_dst_nodes(), dtype=dtype, device=dev) graph.dstdata["_dst_in_w"] = th.ones(
graph.edata['_edge_w'] = edge_weight graph.number_of_dst_nodes(), dtype=dtype, device=dev
)
if self._norm == 'both': graph.edata["_edge_w"] = edge_weight
if self._norm == "both":
reversed_g = reverse(graph) reversed_g = reverse(graph)
reversed_g.edata['_edge_w'] = edge_weight reversed_g.edata["_edge_w"] = edge_weight
reversed_g.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'out_weight')) reversed_g.update_all(
degs = reversed_g.dstdata['out_weight'] + self._eps fn.copy_e("_edge_w", "m"), fn.sum("m", "out_weight")
)
degs = reversed_g.dstdata["out_weight"] + self._eps
norm = th.pow(degs, -0.5) norm = th.pow(degs, -0.5)
graph.srcdata['_src_out_w'] = norm graph.srcdata["_src_out_w"] = norm
if self._norm != 'none': if self._norm != "none":
graph.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'in_weight')) graph.update_all(
degs = graph.dstdata['in_weight'] + self._eps fn.copy_e("_edge_w", "m"), fn.sum("m", "in_weight")
if self._norm == 'both': )
degs = graph.dstdata["in_weight"] + self._eps
if self._norm == "both":
norm = th.pow(degs, -0.5) norm = th.pow(degs, -0.5)
else: else:
norm = 1.0 / degs norm = 1.0 / degs
graph.dstdata['_dst_in_w'] = norm graph.dstdata["_dst_in_w"] = norm
graph.apply_edges(
lambda e: {
"_norm_edge_weights": e.src["_src_out_w"]
* e.dst["_dst_in_w"]
* e.data["_edge_w"]
}
)
return graph.edata["_norm_edge_weights"]
graph.apply_edges(lambda e: {'_norm_edge_weights': e.src['_src_out_w'] * \
e.dst['_dst_in_w'] * \
e.data['_edge_w']})
return graph.edata['_norm_edge_weights']
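The normalized weights produced above are meant to be paired with a convolution that skips its own degree normalization. A minimal sketch of that pairing (the graph, sizes and weights are made up; the key point is ``norm="none"`` on the conv):

import torch as th
import dgl
from dgl.nn.pytorch import EdgeWeightNorm, GraphConv

g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
g = dgl.add_self_loop(g)                        # avoids the 0-in-degree error raised above
feat = th.ones(6, 10)
edge_weight = th.rand(g.num_edges()) + 0.1      # strictly positive, as norm="both" requires
norm = EdgeWeightNorm(norm="both")
norm_weight = norm(g, edge_weight)              # w_ij / (sqrt(sum_k w_ik) * sqrt(sum_k w_kj))
conv = GraphConv(10, 2, norm="none")            # degree normalization already in norm_weight
res = conv(g, feat, edge_weight=norm_weight)    # -> shape (6, 2)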
# pylint: disable=W0235 # pylint: disable=W0235
class GraphConv(nn.Module): class GraphConv(nn.Module):
...@@ -266,18 +283,23 @@ class GraphConv(nn.Module): ...@@ -266,18 +283,23 @@ class GraphConv(nn.Module):
[-0.5287, 0.8235], [-0.5287, 0.8235],
[-0.2994, 0.6106]], grad_fn=<AddBackward0>) [-0.2994, 0.6106]], grad_fn=<AddBackward0>)
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
norm='both', norm="both",
weight=True, weight=True,
bias=True, bias=True,
activation=None, activation=None,
allow_zero_in_degree=False): allow_zero_in_degree=False,
):
super(GraphConv, self).__init__() super(GraphConv, self).__init__()
if norm not in ('none', 'both', 'right', 'left'): if norm not in ("none", "both", "right", "left"):
raise DGLError('Invalid norm value. Must be either "none", "both", "right" or "left".' raise DGLError(
' But got "{}".'.format(norm)) 'Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm)
)
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
self._norm = norm self._norm = norm
...@@ -286,12 +308,12 @@ class GraphConv(nn.Module): ...@@ -286,12 +308,12 @@ class GraphConv(nn.Module):
if weight: if weight:
self.weight = nn.Parameter(th.Tensor(in_feats, out_feats)) self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
else: else:
self.register_parameter('weight', None) self.register_parameter("weight", None)
if bias: if bias:
self.bias = nn.Parameter(th.Tensor(out_feats)) self.bias = nn.Parameter(th.Tensor(out_feats))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
self.reset_parameters() self.reset_parameters()
...@@ -383,26 +405,28 @@ class GraphConv(nn.Module): ...@@ -383,26 +405,28 @@ class GraphConv(nn.Module):
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any(): if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
aggregate_fn = fn.copy_u('h', 'm') "suppress the check and let the code run."
)
aggregate_fn = fn.copy_u("h", "m")
if edge_weight is not None: if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges() assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight graph.edata["_edge_weight"] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm') aggregate_fn = fn.u_mul_e("h", "_edge_weight", "m")
# (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite. # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm in ['left', 'both']: if self._norm in ["left", "both"]:
degs = graph.out_degrees().to(feat_src).clamp(min=1) degs = graph.out_degrees().to(feat_src).clamp(min=1)
if self._norm == 'both': if self._norm == "both":
norm = th.pow(degs, -0.5) norm = th.pow(degs, -0.5)
else: else:
norm = 1.0 / degs norm = 1.0 / degs
...@@ -412,9 +436,11 @@ class GraphConv(nn.Module): ...@@ -412,9 +436,11 @@ class GraphConv(nn.Module):
if weight is not None: if weight is not None:
if self.weight is not None: if self.weight is not None:
raise DGLError('External weight is provided while at the same time the' raise DGLError(
' module has defined its own weight parameter. Please' "External weight is provided while at the same time the"
' create the module with flag weight=False.') " module has defined its own weight parameter. Please"
" create the module with flag weight=False."
)
else: else:
weight = self.weight weight = self.weight
...@@ -422,20 +448,20 @@ class GraphConv(nn.Module): ...@@ -422,20 +448,20 @@ class GraphConv(nn.Module):
# mult W first to reduce the feature size for aggregation. # mult W first to reduce the feature size for aggregation.
if weight is not None: if weight is not None:
feat_src = th.matmul(feat_src, weight) feat_src = th.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h')) graph.update_all(aggregate_fn, fn.sum(msg="m", out="h"))
rst = graph.dstdata['h'] rst = graph.dstdata["h"]
else: else:
# aggregate first then mult W # aggregate first then mult W
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h')) graph.update_all(aggregate_fn, fn.sum(msg="m", out="h"))
rst = graph.dstdata['h'] rst = graph.dstdata["h"]
if weight is not None: if weight is not None:
rst = th.matmul(rst, weight) rst = th.matmul(rst, weight)
if self._norm in ['right', 'both']: if self._norm in ["right", "both"]:
degs = graph.in_degrees().to(feat_dst).clamp(min=1) degs = graph.in_degrees().to(feat_dst).clamp(min=1)
if self._norm == 'both': if self._norm == "both":
norm = th.pow(degs, -0.5) norm = th.pow(degs, -0.5)
else: else:
norm = 1.0 / degs norm = 1.0 / degs
...@@ -455,8 +481,8 @@ class GraphConv(nn.Module): ...@@ -455,8 +481,8 @@ class GraphConv(nn.Module):
"""Set the extra representation of the module, """Set the extra representation of the module,
which will come into effect when printing the model. which will come into effect when printing the model.
""" """
summary = 'in={_in_feats}, out={_out_feats}' summary = "in={_in_feats}, out={_out_feats}"
summary += ', normalization={_norm}' summary += ", normalization={_norm}"
if '_activation' in self.__dict__: if "_activation" in self.__dict__:
summary += ', activation={_activation}' summary += ", activation={_activation}"
return summary.format(**self.__dict__) return summary.format(**self.__dict__)
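In equation form, the GraphConv forward pass above computes, for ``norm="both"`` and an optional edge weight :math:`e_{ji}` defaulting to 1,

.. math::
    h_i' = \sigma\!\left(b + \sum_{j \in \mathcal{N}(i)} \frac{e_{ji}}{c_{ji}}\, h_j W\right),
    \qquad
    c_{ji} = \sqrt{|\mathcal{N}(j)|}\, \sqrt{|\mathcal{N}(i)|},

where :math:`|\mathcal{N}(j)|` is the out-degree of :math:`j` and :math:`|\mathcal{N}(i)|` the in-degree of :math:`i`, both clamped to at least 1 (``norm="left"`` divides by :math:`|\mathcal{N}(j)|` only, ``norm="right"`` by :math:`|\mathcal{N}(i)|` only, and ``norm="none"`` applies no normalization). The two branches around the comment ``# mult W first to reduce the feature size for aggregation`` are purely an optimization: multiplying by :math:`W` before aggregation shrinks the messages when the output dimension is smaller than the input one, and aggregating first is cheaper otherwise; both orders give the same result.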
...@@ -144,9 +144,9 @@ class HGTConv(nn.Module): ...@@ -144,9 +144,9 @@ class HGTConv(nn.Module):
self.presorted = presorted self.presorted = presorted
if g.is_block: if g.is_block:
x_src = x x_src = x
x_dst = x[:g.num_dst_nodes()] x_dst = x[: g.num_dst_nodes()]
srcntype = ntype srcntype = ntype
dstntype = ntype[:g.num_dst_nodes()] dstntype = ntype[: g.num_dst_nodes()]
else: else:
x_src = x x_src = x
x_dst = x x_dst = x
......
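The slicing in the HGTConv hunk uses the same block convention as ``GATConv`` earlier in this diff: in a block (message flow graph), the destination nodes are stored as the first ``num_dst_nodes()`` entries of the source node set, so a prefix slice of the source features yields the destination features. A tiny standalone sketch of that convention, assuming the usual ``dgl.to_block`` behaviour:

import torch as th
import dgl

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
block = dgl.to_block(g, dst_nodes=th.tensor([1, 2, 3]))
x = th.randn(block.num_src_nodes(), 8)          # one row per source node
x_dst = x[: block.num_dst_nodes()]              # the first rows are the destination nodes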
...@@ -5,8 +5,8 @@ from torch import nn ...@@ -5,8 +5,8 @@ from torch import nn
from torch.nn import init from torch.nn import init
from .... import function as fn from .... import function as fn
from ..utils import Identity
from ....utils import expand_as_pair from ....utils import expand_as_pair
from ..utils import Identity
class NNConv(nn.Module): class NNConv(nn.Module):
...@@ -84,37 +84,44 @@ class NNConv(nn.Module): ...@@ -84,37 +84,44 @@ class NNConv(nn.Module):
[ 0.1261, -0.0155], [ 0.1261, -0.0155],
[-0.6568, 0.5042]], grad_fn=<AddBackward0>) [-0.6568, 0.5042]], grad_fn=<AddBackward0>)
""" """
def __init__(self,
def __init__(
self,
in_feats, in_feats,
out_feats, out_feats,
edge_func, edge_func,
aggregator_type='mean', aggregator_type="mean",
residual=False, residual=False,
bias=True): bias=True,
):
super(NNConv, self).__init__() super(NNConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self.edge_func = edge_func self.edge_func = edge_func
if aggregator_type == 'sum': if aggregator_type == "sum":
self.reducer = fn.sum self.reducer = fn.sum
elif aggregator_type == 'mean': elif aggregator_type == "mean":
self.reducer = fn.mean self.reducer = fn.mean
elif aggregator_type == 'max': elif aggregator_type == "max":
self.reducer = fn.max self.reducer = fn.max
else: else:
raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type)) raise KeyError(
"Aggregator type {} not recognized: ".format(aggregator_type)
)
self._aggre_type = aggregator_type self._aggre_type = aggregator_type
if residual: if residual:
if self._in_dst_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(self._in_dst_feats, out_feats, bias=False) self.res_fc = nn.Linear(
self._in_dst_feats, out_feats, bias=False
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
self.register_buffer('res_fc', None) self.register_buffer("res_fc", None)
if bias: if bias:
self.bias = nn.Parameter(th.Tensor(out_feats)) self.bias = nn.Parameter(th.Tensor(out_feats))
else: else:
self.register_buffer('bias', None) self.register_buffer("bias", None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
...@@ -129,7 +136,7 @@ class NNConv(nn.Module): ...@@ -129,7 +136,7 @@ class NNConv(nn.Module):
The model parameters are initialized using Glorot uniform initialization The model parameters are initialized using Glorot uniform initialization
and the bias is initialized to be zero. and the bias is initialized to be zero.
""" """
gain = init.calculate_gain('relu') gain = init.calculate_gain("relu")
if self.bias is not None: if self.bias is not None:
nn.init.zeros_(self.bias) nn.init.zeros_(self.bias)
if isinstance(self.res_fc, nn.Linear): if isinstance(self.res_fc, nn.Linear):
...@@ -161,12 +168,16 @@ class NNConv(nn.Module): ...@@ -161,12 +168,16 @@ class NNConv(nn.Module):
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
# (n, d_in, 1) # (n, d_in, 1)
graph.srcdata['h'] = feat_src.unsqueeze(-1) graph.srcdata["h"] = feat_src.unsqueeze(-1)
# (n, d_in, d_out) # (n, d_in, d_out)
graph.edata['w'] = self.edge_func(efeat).view(-1, self._in_src_feats, self._out_feats) graph.edata["w"] = self.edge_func(efeat).view(
-1, self._in_src_feats, self._out_feats
)
# (n, d_in, d_out) # (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh')) graph.update_all(
rst = graph.dstdata['neigh'].sum(dim=1) # (n, d_out) fn.u_mul_e("h", "w", "m"), self.reducer("m", "neigh")
)
rst = graph.dstdata["neigh"].sum(dim=1) # (n, d_out)
# residual connection # residual connection
if self.res_fc is not None: if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst) rst = rst + self.res_fc(feat_dst)
......
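Reading the NNConv forward pass above end to end: ``edge_func`` maps each edge feature to a flattened :math:`d_{in} \times d_{out}` matrix :math:`\Theta(e_{ij})`, and for the sum/mean aggregators the layer reduces to :math:`h_i' = \mathrm{aggregate}_{j \in \mathcal{N}(i)}\, \Theta(e_{ij})^\top h_j` plus the optional residual and bias. A minimal sketch (the sizes and the edge network are made up for illustration):

import torch as th
import torch.nn as nn
import dgl
from dgl.nn.pytorch import NNConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))     # a 4-node directed cycle
node_feat = th.randn(4, 10)
edge_feat = th.randn(4, 5)
edge_func = nn.Linear(5, 10 * 2)                # 5-dim edge feature -> flattened 10x2 weight
conv = NNConv(10, 2, edge_func, aggregator_type="mean")
out = conv(g, node_feat, edge_feat)             # -> shape (4, 2)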