Unverified commit 63ac788f authored by Hongzhi (Steve), Chen, committed by GitHub

auto-reformat-nn (#5319)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 0b3a447b
......@@ -6,8 +6,7 @@ import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from .... import broadcast_nodes
from .... import function as fn
from .... import broadcast_nodes, function as fn
from ....base import dgl_warning
......
......@@ -98,20 +98,21 @@ class EdgeConv(nn.Block):
[-1.015364 0.78919804]]
<NDArray 4x2 @cpu(0)>
"""
def __init__(self,
in_feat,
out_feat,
batch_norm=False,
allow_zero_in_degree=False):
def __init__(
self, in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False
):
super(EdgeConv, self).__init__()
self.batch_norm = batch_norm
self._allow_zero_in_degree = allow_zero_in_degree
with self.name_scope():
self.theta = nn.Dense(out_feat, in_units=in_feat,
weight_initializer=mx.init.Xavier())
self.phi = nn.Dense(out_feat, in_units=in_feat,
weight_initializer=mx.init.Xavier())
self.theta = nn.Dense(
out_feat, in_units=in_feat, weight_initializer=mx.init.Xavier()
)
self.phi = nn.Dense(
out_feat, in_units=in_feat, weight_initializer=mx.init.Xavier()
)
if batch_norm:
self.bn = nn.BatchNorm(in_channels=out_feat)
......@@ -164,26 +165,28 @@ class EdgeConv(nn.Block):
with g.local_scope():
if not self._allow_zero_in_degree:
if g.in_degrees().min() == 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
h_src, h_dst = expand_as_pair(h, g)
g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst
g.apply_edges(fn.v_sub_u('x', 'x', 'theta'))
g.edata['theta'] = self.theta(g.edata['theta'])
g.dstdata['phi'] = self.phi(g.dstdata['x'])
g.srcdata["x"] = h_src
g.dstdata["x"] = h_dst
g.apply_edges(fn.v_sub_u("x", "x", "theta"))
g.edata["theta"] = self.theta(g.edata["theta"])
g.dstdata["phi"] = self.phi(g.dstdata["x"])
if not self.batch_norm:
g.update_all(fn.e_add_v('theta', 'phi', 'e'), fn.max('e', 'x'))
g.update_all(fn.e_add_v("theta", "phi", "e"), fn.max("e", "x"))
else:
g.apply_edges(fn.e_add_v('theta', 'phi', 'e'))
g.edata['e'] = self.bn(g.edata['e'])
g.update_all(fn.copy_e('e', 'm'), fn.max('m', 'x'))
return g.dstdata['x']
g.apply_edges(fn.e_add_v("theta", "phi", "e"))
g.edata["e"] = self.bn(g.edata["e"])
g.update_all(fn.copy_e("e", "m"), fn.max("m", "x"))
return g.dstdata["x"]
"""MXNet modules for graph attention networks(GAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity
from .... import function as fn
from ....base import DGLError
from ...functional import edge_softmax
from ....utils import expand_as_pair
from ...functional import edge_softmax
#pylint: enable=W0235
# pylint: enable=W0235
class GATConv(nn.Block):
r"""Graph attention layer from `Graph Attention Network
<https://arxiv.org/pdf/1710.10903.pdf>`__
......@@ -134,16 +136,19 @@ class GATConv(nn.Block):
[-1.9325689 1.3824553 ]]]
<NDArray 4x3x2 @cpu(0)>
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
):
super(GATConv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -152,31 +157,48 @@ class GATConv(nn.Block):
self._allow_zero_in_degree = allow_zero_in_degree
with self.name_scope():
if isinstance(in_feats, tuple):
self.fc_src = nn.Dense(out_feats * num_heads, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats)
self.fc_dst = nn.Dense(out_feats * num_heads, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats)
self.fc_src = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats,
)
self.fc_dst = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats,
)
else:
self.fc = nn.Dense(out_feats * num_heads, use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
self.attn_l = self.params.get('attn_l',
shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
self.attn_r = self.params.get('attn_r',
shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
self.fc = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats,
)
self.attn_l = self.params.get(
"attn_l",
shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
self.attn_r = self.params.get(
"attn_r",
shape=(1, num_heads, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual:
if in_feats != out_feats:
self.res_fc = nn.Dense(out_feats * num_heads, use_bias=False,
weight_initializer=mx.init.Xavier(
magnitude=math.sqrt(2.0)),
in_units=in_feats)
self.res_fc = nn.Dense(
out_feats * num_heads,
use_bias=False,
weight_initializer=mx.init.Xavier(
magnitude=math.sqrt(2.0)
),
in_units=in_feats,
)
else:
self.res_fc = Identity()
else:
......@@ -235,15 +257,17 @@ class GATConv(nn.Block):
with graph.local_scope():
if not self._allow_zero_in_degree:
if graph.in_degrees().min() == 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
src_prefix_shape = feat[0].shape[:-1]
......@@ -251,22 +275,27 @@ class GATConv(nn.Block):
feat_dim = feat[0].shape[-1]
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'):
if not hasattr(self, "fc_src"):
self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape(
*src_prefix_shape, self._num_heads, self._out_feats)
*src_prefix_shape, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape(
*dst_prefix_shape, self._num_heads, self._out_feats)
*dst_prefix_shape, self._num_heads, self._out_feats
)
else:
src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
feat_dim = feat[0].shape[-1]
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src.reshape(-1, feat_dim)).reshape(
*src_prefix_shape, self._num_heads, self._out_feats)
feat_src = feat_dst = self.fc(
h_src.reshape(-1, feat_dim)
).reshape(*src_prefix_shape, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[: graph.number_of_dst_nodes()]
dst_prefix_shape = (
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
......@@ -277,28 +306,36 @@ class GATConv(nn.Block):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
er = (feat_dst * self.attn_r.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
el = (
(feat_src * self.attn_l.data(feat_src.context))
.sum(axis=-1)
.expand_dims(-1)
)
er = (
(feat_dst * self.attn_r.data(feat_src.context))
.sum(axis=-1)
.expand_dims(-1)
)
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst.reshape(-1, feat_dim)).reshape(
*dst_prefix_shape, -1, self._out_feats)
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
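Note: the NOTE in the forward pass above states that "first projection then addition" is mathematically equivalent to the GAT paper's "first concatenation then linear projection". In the code's notation, where el and er are a_l^T W h_i and a_r^T W h_j, the equivalence is just a split of the attention vector a into [a_l || a_r]:

\[
a^{\top} \big[ W h_i \,\|\, W h_j \big]
  = \big[ a_l \,\|\, a_r \big]^{\top} \big[ W h_i \,\|\, W h_j \big]
  = a_l^{\top} W h_i + a_r^{\top} W h_j .
\]

The two per-node terms can therefore be computed once per node and combined per edge with fn.u_add_v, which is what the code does.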
"""Torch Module for GMM Conv"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
......@@ -107,15 +108,18 @@ class GMMConv(nn.Block):
[-0.1005067 -0.09494358]]
<NDArray 4x2 @cpu(0)>
"""
def __init__(self,
in_feats,
out_feats,
dim,
n_kernels,
aggregator_type='sum',
residual=False,
bias=True,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
dim,
n_kernels,
aggregator_type="sum",
residual=False,
bias=True,
allow_zero_in_degree=False,
):
super(GMMConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -123,38 +127,44 @@ class GMMConv(nn.Block):
self._dim = dim
self._n_kernels = n_kernels
self._allow_zero_in_degree = allow_zero_in_degree
if aggregator_type == 'sum':
if aggregator_type == "sum":
self._reducer = fn.sum
elif aggregator_type == 'mean':
elif aggregator_type == "mean":
self._reducer = fn.mean
elif aggregator_type == 'max':
elif aggregator_type == "max":
self._reducer = fn.max
else:
raise KeyError("Aggregator type {} not recognized.".format(aggregator_type))
raise KeyError(
"Aggregator type {} not recognized.".format(aggregator_type)
)
with self.name_scope():
self.mu = self.params.get('mu',
shape=(n_kernels, dim),
init=mx.init.Normal(0.1))
self.inv_sigma = self.params.get('inv_sigma',
shape=(n_kernels, dim),
init=mx.init.Constant(1))
self.fc = nn.Dense(n_kernels * out_feats,
in_units=self._in_src_feats,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)))
self.mu = self.params.get(
"mu", shape=(n_kernels, dim), init=mx.init.Normal(0.1)
)
self.inv_sigma = self.params.get(
"inv_sigma", shape=(n_kernels, dim), init=mx.init.Constant(1)
)
self.fc = nn.Dense(
n_kernels * out_feats,
in_units=self._in_src_feats,
use_bias=False,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
if residual:
if self._in_dst_feats != out_feats:
self.res_fc = nn.Dense(out_feats, in_units=self._in_dst_feats, use_bias=False)
self.res_fc = nn.Dense(
out_feats, in_units=self._in_dst_feats, use_bias=False
)
else:
self.res_fc = Identity()
else:
self.res_fc = None
if bias:
self.bias = self.params.get('bias',
shape=(out_feats,),
init=mx.init.Zero())
self.bias = self.params.get(
"bias", shape=(out_feats,), init=mx.init.Zero()
)
else:
self.bias = None
......@@ -208,32 +218,44 @@ class GMMConv(nn.Block):
"""
if not self._allow_zero_in_degree:
if graph.in_degrees().min() == 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph)
with graph.local_scope():
graph.srcdata['h'] = self.fc(feat_src).reshape(
-1, self._n_kernels, self._out_feats)
graph.srcdata["h"] = self.fc(feat_src).reshape(
-1, self._n_kernels, self._out_feats
)
E = graph.number_of_edges()
# compute gaussian weight
gaussian = -0.5 * ((pseudo.reshape(E, 1, self._dim) -
self.mu.data(feat_src.context)
.reshape(1, self._n_kernels, self._dim)) ** 2)
gaussian = gaussian *\
(self.inv_sigma.data(feat_src.context)
.reshape(1, self._n_kernels, self._dim) ** 2)
gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1)
graph.edata['w'] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
rst = graph.dstdata['h'].sum(1)
gaussian = -0.5 * (
(
pseudo.reshape(E, 1, self._dim)
- self.mu.data(feat_src.context).reshape(
1, self._n_kernels, self._dim
)
)
** 2
)
gaussian = gaussian * (
self.inv_sigma.data(feat_src.context).reshape(
1, self._n_kernels, self._dim
)
** 2
)
gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1)
graph.edata["w"] = gaussian
graph.update_all(fn.u_mul_e("h", "w", "m"), self._reducer("m", "h"))
rst = graph.dstdata["h"].sum(1)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst)
......
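Note: the Gaussian weight assembled in the GMMConv forward above (identically in the old and the reformatted code) is, with pseudo-coordinate u_{ij} on each edge, learned means mu_k and learned inverse widths sigma_k^{-1}:

\[
w_k(u_{ij}) = \exp\!\Big( -\tfrac{1}{2} \sum_{d=1}^{D} \big(u_{ij,d} - \mu_{k,d}\big)^{2} \big(\sigma_{k,d}^{-1}\big)^{2} \Big),
\qquad
h_i' = \sum_{k=1}^{K} \operatorname{agg}_{j \in \mathcal{N}(i)} \big( w_k(u_{ij})\, W_k h_j \big),
\]

where W_k h_j is the k-th slice of self.fc(feat_src) and agg is the sum/mean/max reducer selected by aggregator_type.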
......@@ -9,6 +9,7 @@ from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
class GraphConv(gluon.Block):
r"""Graph convolutional layer from `Semi-Supervised Classification with Graph Convolutional
Networks <https://arxiv.org/abs/1609.02907>`__
......@@ -133,18 +134,23 @@ class GraphConv(gluon.Block):
[ 0.26967263 0.308129 ]]
<NDArray 4x2 @cpu(0)>
"""
def __init__(self,
in_feats,
out_feats,
norm='both',
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
norm="both",
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False,
):
super(GraphConv, self).__init__()
if norm not in ('none', 'both', 'right', 'left'):
raise DGLError('Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm))
if norm not in ("none", "both", "right", "left"):
raise DGLError(
'Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm)
)
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
......@@ -152,14 +158,18 @@ class GraphConv(gluon.Block):
with self.name_scope():
if weight:
self.weight = self.params.get('weight', shape=(in_feats, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
self.weight = self.params.get(
"weight",
shape=(in_feats, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
else:
self.weight = None
if bias:
self.bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Zero())
self.bias = self.params.get(
"bias", shape=(out_feats,), init=mx.init.Zero()
)
else:
self.bias = None
......@@ -225,21 +235,27 @@ class GraphConv(gluon.Block):
with graph.local_scope():
if not self._allow_zero_in_degree:
if graph.in_degrees().min() == 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm in ['both', 'left']:
degs = graph.out_degrees().as_in_context(feat_dst.context).astype('float32')
if self._norm in ["both", "left"]:
degs = (
graph.out_degrees()
.as_in_context(feat_dst.context)
.astype("float32")
)
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
if self._norm == 'both':
if self._norm == "both":
norm = mx.nd.power(degs, -0.5)
else:
norm = 1.0 / degs
......@@ -247,12 +263,13 @@ class GraphConv(gluon.Block):
norm = norm.reshape(shp)
feat_src = feat_src * norm
if weight is not None:
if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
raise DGLError(
"External weight is provided while at the same time the"
" module has defined its own weight parameter. Please"
" create the module with flag weight=False."
)
else:
weight = self.weight.data(feat_src.context)
......@@ -260,23 +277,29 @@ class GraphConv(gluon.Block):
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat_src = mx.nd.dot(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
graph.srcdata["h"] = feat_src
graph.update_all(
fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
)
rst = graph.dstdata.pop("h")
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
graph.srcdata["h"] = feat_src
graph.update_all(
fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
)
rst = graph.dstdata.pop("h")
if weight is not None:
rst = mx.nd.dot(rst, weight)
if self._norm in ['both', 'right']:
degs = graph.in_degrees().as_in_context(feat_dst.context).astype('float32')
if self._norm in ["both", "right"]:
degs = (
graph.in_degrees()
.as_in_context(feat_dst.context)
.astype("float32")
)
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
if self._norm == 'both':
if self._norm == "both":
norm = mx.nd.power(degs, -0.5)
else:
norm = 1.0 / degs
......@@ -293,9 +316,9 @@ class GraphConv(gluon.Block):
return rst
def __repr__(self):
summary = 'GraphConv('
summary += 'in={:d}, out={:d}, normalization={}, activation={}'.format(
self._in_feats, self._out_feats,
self._norm, self._activation)
summary += ')'
summary = "GraphConv("
summary += "in={:d}, out={:d}, normalization={}, activation={}".format(
self._in_feats, self._out_feats, self._norm, self._activation
)
summary += ")"
return summary
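Note: the two degree-normalization blocks in GraphConv.forward above (out-degrees applied to the source features before aggregation, in-degrees applied to the result afterwards, both clipped to at least 1) amount to:

\[
h_i' = \sigma\!\Big( b + \sum_{j \in \mathcal{N}(i)} \frac{1}{c_{ji}} \, h_j W \Big),
\qquad
c_{ji} =
\begin{cases}
\sqrt{d_j^{\mathrm{out}}}\,\sqrt{d_i^{\mathrm{in}}} & \texttt{norm="both"} \\
d_j^{\mathrm{out}} & \texttt{norm="left"} \\
d_i^{\mathrm{in}} & \texttt{norm="right"} \\
1 & \texttt{norm="none"}
\end{cases}
\]

with sigma the optional activation and b the optional bias.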
......@@ -89,24 +89,29 @@ class NNConv(nn.Block):
[ 0.24425688 0.3238042 ]]
<NDArray 4x2 @cpu(0)>
"""
def __init__(self,
in_feats,
out_feats,
edge_func,
aggregator_type,
residual=False,
bias=True):
def __init__(
self,
in_feats,
out_feats,
edge_func,
aggregator_type,
residual=False,
bias=True,
):
super(NNConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
if aggregator_type == 'sum':
if aggregator_type == "sum":
self.reducer = fn.sum
elif aggregator_type == 'mean':
elif aggregator_type == "mean":
self.reducer = fn.mean
elif aggregator_type == 'max':
elif aggregator_type == "max":
self.reducer = fn.max
else:
raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type))
raise KeyError(
"Aggregator type {} not recognized: ".format(aggregator_type)
)
self._aggre_type = aggregator_type
with self.name_scope():
......@@ -114,17 +119,20 @@ class NNConv(nn.Block):
if residual:
if self._in_dst_feats != out_feats:
self.res_fc = nn.Dense(
out_feats, in_units=self._in_dst_feats,
use_bias=False, weight_initializer=mx.init.Xavier())
out_feats,
in_units=self._in_dst_feats,
use_bias=False,
weight_initializer=mx.init.Xavier(),
)
else:
self.res_fc = Identity()
else:
self.res_fc = None
if bias:
self.bias = self.params.get('bias',
shape=(out_feats,),
init=mx.init.Zero())
self.bias = self.params.get(
"bias", shape=(out_feats,), init=mx.init.Zero()
)
else:
self.bias = None
......@@ -153,12 +161,16 @@ class NNConv(nn.Block):
feat_src, feat_dst = expand_as_pair(feat, graph)
# (n, d_in, 1)
graph.srcdata['h'] = feat_src.expand_dims(-1)
graph.srcdata["h"] = feat_src.expand_dims(-1)
# (n, d_in, d_out)
graph.edata['w'] = self.edge_nn(efeat).reshape(-1, self._in_src_feats, self._out_feats)
graph.edata["w"] = self.edge_nn(efeat).reshape(
-1, self._in_src_feats, self._out_feats
)
# (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh'))
rst = graph.dstdata.pop('neigh').sum(axis=1) # (n, d_out)
graph.update_all(
fn.u_mul_e("h", "w", "m"), self.reducer("m", "neigh")
)
rst = graph.dstdata.pop("neigh").sum(axis=1) # (n, d_out)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst)
......
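Note: in the NNConv forward pass above, edge_nn turns each edge feature into a (d_in, d_out) matrix that is applied to the source feature; the reshape/broadcast/sum sequence, read for the default "sum" aggregator (the symbol f_Theta below is just notation for edge_nn, not a name from the code), is:

\[
h_i' = \sum_{j \in \mathcal{N}(i)} f_{\Theta}(e_{ji})^{\top} h_j ,
\qquad f_{\Theta}(e_{ji}) \in \mathbb{R}^{d_{\mathrm{in}} \times d_{\mathrm{out}}} .
\]

For "mean" and "max" the per-edge (d_in, d_out) messages are reduced element-wise across edges first and only then summed over the input dimension.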
"""MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
from ....utils import check_eq_shape, expand_as_pair
class SAGEConv(nn.Block):
r"""GraphSAGE layer from `Inductive Representation Learning on
......@@ -89,20 +91,25 @@ class SAGEConv(nn.Block):
[-1.0509381 2.2239418 ]]
<NDArray 4x2 @cpu(0)>
"""
def __init__(self,
in_feats,
out_feats,
aggregator_type='mean',
feat_drop=0.,
bias=True,
norm=None,
activation=None):
def __init__(
self,
in_feats,
out_feats,
aggregator_type="mean",
feat_drop=0.0,
bias=True,
norm=None,
activation=None,
):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
valid_aggre_types = {"mean", "gcn", "pool", "lstm"}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
"Invalid aggregator_type. Must be one of {}. "
"But got {!r} instead.".format(
valid_aggre_types, aggregator_type
)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -112,19 +119,28 @@ class SAGEConv(nn.Block):
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
if aggregator_type == 'pool':
self.fc_pool = nn.Dense(self._in_src_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats)
if aggregator_type == 'lstm':
if aggregator_type == "pool":
self.fc_pool = nn.Dense(
self._in_src_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats,
)
if aggregator_type == "lstm":
raise NotImplementedError
if aggregator_type != 'gcn':
self.fc_self = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats)
self.fc_neigh = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats)
if aggregator_type != "gcn":
self.fc_self = nn.Dense(
out_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_dst_feats,
)
self.fc_neigh = nn.Dense(
out_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=self._in_src_feats,
)
def forward(self, graph, feat):
r"""Compute GraphSAGE layer.
......@@ -153,39 +169,47 @@ class SAGEConv(nn.Block):
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
dst_neigh = mx.nd.zeros((graph.number_of_dst_nodes(), self._in_src_feats))
dst_neigh = mx.nd.zeros(
(graph.number_of_dst_nodes(), self._in_src_feats)
)
dst_neigh = dst_neigh.as_in_context(feat_dst.context)
graph.dstdata['neigh'] = dst_neigh
graph.dstdata["neigh"] = dst_neigh
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
if self._aggre_type == "mean":
graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u("h", "m"), fn.mean("m", "neigh"))
h_neigh = graph.dstdata["neigh"]
elif self._aggre_type == "gcn":
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
graph.srcdata["h"] = feat_src
graph.dstdata["h"] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "neigh"))
# divide in degrees
degs = graph.in_degrees().astype(feat_dst.dtype)
degs = degs.as_in_context(feat_dst.context)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.expand_dims(-1) + 1)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = nd.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
degs.expand_dims(-1) + 1
)
elif self._aggre_type == "pool":
graph.srcdata["h"] = nd.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u("h", "m"), fn.max("m", "neigh"))
h_neigh = graph.dstdata["neigh"]
elif self._aggre_type == "lstm":
raise NotImplementedError
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
raise KeyError(
"Aggregator type {} not recognized.".format(
self._aggre_type
)
)
if self._aggre_type == 'gcn':
if self._aggre_type == "gcn":
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
......
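Note: the SAGEConv aggregator branches above differ only in how h_neigh is formed before the final dense layer(s); in particular, the "gcn" branch divides by in-degree plus one so that the node's own feature is averaged in. Writing W_self, W_neigh and W_pool for fc_self, fc_neigh and fc_pool (their optional biases absorbed into the symbols):

\[
h_{\mathcal{N}(i)} =
\begin{cases}
\operatorname{mean}_{j \in \mathcal{N}(i)} h_j & \texttt{"mean"} \\[2pt]
\dfrac{h_i + \sum_{j \in \mathcal{N}(i)} h_j}{|\mathcal{N}(i)| + 1} & \texttt{"gcn"} \\[6pt]
\max_{j \in \mathcal{N}(i)} \operatorname{ReLU}\!\big(W_{\mathrm{pool}} h_j\big) & \texttt{"pool"}
\end{cases}
\qquad
h_i' =
\begin{cases}
W_{\mathrm{neigh}}\, h_{\mathcal{N}(i)} & \texttt{"gcn"} \\
W_{\mathrm{self}}\, h_i + W_{\mathrm{neigh}}\, h_{\mathcal{N}(i)} & \text{otherwise.}
\end{cases}
\]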
......@@ -60,12 +60,8 @@ class TAGConv(gluon.Block):
[ 0.32964635 -0.7669234 ]]
<NDArray 6x2 @cpu(0)>
"""
def __init__(self,
in_feats,
out_feats,
k=2,
bias=True,
activation=None):
def __init__(self, in_feats, out_feats, k=2, bias=True, activation=None):
super(TAGConv, self).__init__()
self.out_feats = out_feats
self.k = k
......@@ -74,11 +70,14 @@ class TAGConv(gluon.Block):
self.in_feats = in_feats
self.lin = self.params.get(
'weight', shape=(self.in_feats * (self.k + 1), self.out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
"weight",
shape=(self.in_feats * (self.k + 1), self.out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
if self.bias:
self.h_bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Zero())
self.h_bias = self.params.get(
"bias", shape=(out_feats,), init=mx.init.Zero()
)
def forward(self, graph, feat):
r"""
......@@ -102,21 +101,24 @@ class TAGConv(gluon.Block):
is size of output feature.
"""
with graph.local_scope():
assert graph.is_homogeneous, 'Graph is not homogeneous'
assert graph.is_homogeneous, "Graph is not homogeneous"
degs = graph.in_degrees().astype('float32')
norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
degs = graph.in_degrees().astype("float32")
norm = mx.nd.power(
mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5
)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context)
rst = feat
for _ in range(self.k):
rst = rst * norm
graph.ndata['h'] = rst
graph.ndata["h"] = rst
graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.ndata['h']
graph.update_all(
fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
)
rst = graph.ndata["h"]
rst = rst * norm
feat = mx.nd.concat(feat, rst, dim=-1)
......
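Note: the TAGConv loop above repeatedly applies the symmetrically normalized adjacency and concatenates each hop onto the running feature matrix, which the weight of shape (in_feats * (k + 1), out_feats) defined in __init__ then projects; roughly:

\[
H^{\mathrm{out}} = \big[\, X \,\big\|\, \hat A X \,\big\|\, \cdots \,\big\|\, \hat A^{K} X \,\big]\, W + b,
\qquad
\hat A = \tilde D^{-1/2} A \tilde D^{-1/2},
\]

where \tilde D clips in-degrees to at least 1 and K is the k hyperparameter (default 2).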
......@@ -7,7 +7,13 @@ from .glob import *
from .softmax import *
from .factory import *
from .hetero import *
from .utils import Sequential, WeightBasis, JumpingKnowledge, LabelPropagation, LaplacianPosEnc
from .sparse_emb import NodeEmbedding
from .utils import (
JumpingKnowledge,
LabelPropagation,
LaplacianPosEnc,
Sequential,
WeightBasis,
)
from .network_emb import *
from .graph_transformer import *
......@@ -4,8 +4,7 @@ import torch as th
import torch.nn.functional as F
from torch import nn
from .... import broadcast_nodes
from .... import function as fn
from .... import broadcast_nodes, function as fn
from ....base import dgl_warning
......@@ -101,10 +100,9 @@ class ChebConv(nn.Module):
return graph.ndata.pop("h") * D_invsqrt
with graph.local_scope():
D_invsqrt = (
th.pow(graph.in_degrees().to(feat).clamp(min=1), -0.5)
.unsqueeze(-1)
)
D_invsqrt = th.pow(
graph.in_degrees().to(feat).clamp(min=1), -0.5
).unsqueeze(-1)
if lambda_max is None:
dgl_warning(
......
......@@ -5,7 +5,7 @@ from functools import partial
import torch
import torch.nn as nn
from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower
from .pnaconv import AGGREGATORS, PNAConv, PNAConvTower, SCALERS
def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
......
......@@ -3,9 +3,9 @@
from torch import nn
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ....utils import expand_as_pair
from ...functional import edge_softmax
class DotGatConv(nn.Module):
......@@ -118,11 +118,10 @@ class DotGatConv(nn.Module):
[-0.5945, -0.4801],
[ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
allow_zero_in_degree=False):
def __init__(
self, in_feats, out_feats, num_heads, allow_zero_in_degree=False
):
super(DotGatConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
......@@ -130,10 +129,22 @@ class DotGatConv(nn.Module):
self._num_heads = num_heads
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats*self._num_heads, bias=False)
self.fc_src = nn.Linear(
self._in_src_feats,
self._out_feats * self._num_heads,
bias=False,
)
self.fc_dst = nn.Linear(
self._in_dst_feats,
self._out_feats * self._num_heads,
bias=False,
)
else:
self.fc = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
self.fc = nn.Linear(
self._in_src_feats,
self._out_feats * self._num_heads,
bias=False,
)
def forward(self, graph, feat, get_attention=False):
r"""
......@@ -175,45 +186,57 @@ class DotGatConv(nn.Module):
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
# check if feat is a tuple
if isinstance(feat, tuple):
h_src = feat[0]
h_dst = feat[1]
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
-1, self._num_heads, self._out_feats
)
else:
h_src = feat
feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
feat_src = feat_dst = self.fc(h_src).view(
-1, self._num_heads, self._out_feats
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
# Assign features to nodes
graph.srcdata.update({'ft': feat_src})
graph.dstdata.update({'ft': feat_dst})
graph.srcdata.update({"ft": feat_src})
graph.dstdata.update({"ft": feat_dst})
# Step 1. dot product
graph.apply_edges(fn.u_dot_v('ft', 'ft', 'a'))
graph.apply_edges(fn.u_dot_v("ft", "ft", "a"))
# Step 2. edge softmax to compute attention scores
graph.edata['sa'] = edge_softmax(graph, graph.edata['a'] / self._out_feats**0.5)
graph.edata["sa"] = edge_softmax(
graph, graph.edata["a"] / self._out_feats**0.5
)
# Step 3. Broadcast softmax value to each edge, and aggregate dst node
graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
graph.update_all(
fn.u_mul_e("ft", "sa", "attn"), fn.sum("attn", "agg_u")
)
# output results to the destination nodes
rst = graph.dstdata['agg_u']
rst = graph.dstdata["agg_u"]
if get_attention:
return rst, graph.edata['sa']
return rst, graph.edata["sa"]
else:
return rst
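Note: Steps 1-3 above are scaled dot-product attention over each destination node's incoming edges; with d = out_feats and W the (possibly source/destination-specific) linear projection, per head:

\[
\alpha_{ij} = \operatorname{softmax}_{j \in \mathcal{N}(i)}\!\Big( \frac{(W h_j)^{\top} W h_i}{\sqrt{d}} \Big),
\qquad
h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, W h_j .
\]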
......@@ -2,8 +2,9 @@
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn
from ....base import DGLError
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
......@@ -92,11 +93,10 @@ class EdgeConv(nn.Module):
[ 0.2101, 1.3466],
[ 0.2342, -0.9868]], grad_fn=<CopyReduceBackward>)
"""
def __init__(self,
in_feat,
out_feat,
batch_norm=False,
allow_zero_in_degree=False):
def __init__(
self, in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False
):
super(EdgeConv, self).__init__()
self.batch_norm = batch_norm
self._allow_zero_in_degree = allow_zero_in_degree
......@@ -155,26 +155,28 @@ class EdgeConv(nn.Module):
with g.local_scope():
if not self._allow_zero_in_degree:
if (g.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
h_src, h_dst = expand_as_pair(feat, g)
g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst
g.apply_edges(fn.v_sub_u('x', 'x', 'theta'))
g.edata['theta'] = self.theta(g.edata['theta'])
g.dstdata['phi'] = self.phi(g.dstdata['x'])
g.srcdata["x"] = h_src
g.dstdata["x"] = h_dst
g.apply_edges(fn.v_sub_u("x", "x", "theta"))
g.edata["theta"] = self.theta(g.edata["theta"])
g.dstdata["phi"] = self.phi(g.dstdata["x"])
if not self.batch_norm:
g.update_all(fn.e_add_v('theta', 'phi', 'e'), fn.max('e', 'x'))
g.update_all(fn.e_add_v("theta", "phi", "e"), fn.max("e", "x"))
else:
g.apply_edges(fn.e_add_v('theta', 'phi', 'e'))
g.apply_edges(fn.e_add_v("theta", "phi", "e"))
# Although the official implementation includes a per-edge
# batch norm within EdgeConv, I choose to replace it with a
# global batch norm for a number of reasons:
......@@ -194,6 +196,6 @@ class EdgeConv(nn.Module):
# In this case, the learned statistics of each position
# by batch norm is not as meaningful as those learned from
# images.
g.edata['e'] = self.bn(g.edata['e'])
g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
return g.dstdata['x']
g.edata["e"] = self.bn(g.edata["e"])
g.update_all(fn.copy_e("e", "e"), fn.max("e", "x"))
return g.dstdata["x"]
......@@ -5,9 +5,10 @@ from torch import nn
from torch.nn import init
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ....utils import expand_as_pair
from ...functional import edge_softmax
# pylint: enable=W0235
class EGATConv(nn.Module):
......@@ -94,47 +95,63 @@ class EGATConv(nn.Module):
>>> new_node_feats.shape, new_edge_feats.shape, attentions.shape
(torch.Size([4, 3, 10]), torch.Size([5, 3, 5]), torch.Size([5, 3, 1]))
"""
def __init__(self,
in_node_feats,
in_edge_feats,
out_node_feats,
out_edge_feats,
num_heads,
bias=True):
def __init__(
self,
in_node_feats,
in_edge_feats,
out_node_feats,
out_edge_feats,
num_heads,
bias=True,
):
super().__init__()
self._num_heads = num_heads
self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats)
self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(
in_node_feats
)
self._out_node_feats = out_node_feats
self._out_edge_feats = out_edge_feats
if isinstance(in_node_feats, tuple):
self.fc_node_src = nn.Linear(
self._in_src_node_feats, out_node_feats * num_heads, bias=False)
self._in_src_node_feats, out_node_feats * num_heads, bias=False
)
self.fc_ni = nn.Linear(
self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
self._in_src_node_feats, out_edge_feats * num_heads, bias=False
)
self.fc_nj = nn.Linear(
self._in_dst_node_feats, out_edge_feats*num_heads, bias=False)
self._in_dst_node_feats, out_edge_feats * num_heads, bias=False
)
else:
self.fc_node_src = nn.Linear(
self._in_src_node_feats, out_node_feats * num_heads, bias=False)
self._in_src_node_feats, out_node_feats * num_heads, bias=False
)
self.fc_ni = nn.Linear(
self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
self._in_src_node_feats, out_edge_feats * num_heads, bias=False
)
self.fc_nj = nn.Linear(
self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
self._in_src_node_feats, out_edge_feats * num_heads, bias=False
)
self.fc_fij = nn.Linear(
in_edge_feats, out_edge_feats * num_heads, bias=False
)
self.attn = nn.Parameter(
th.FloatTensor(size=(1, num_heads, out_edge_feats))
)
if bias:
self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
self.bias = nn.Parameter(
th.FloatTensor(size=(num_heads * out_edge_feats,))
)
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
self.reset_parameters()
def reset_parameters(self):
"""
Reinitialize learnable parameters.
"""
gain = init.calculate_gain('relu')
gain = init.calculate_gain("relu")
init.xavier_normal_(self.fc_node_src.weight, gain=gain)
init.xavier_normal_(self.fc_ni.weight, gain=gain)
init.xavier_normal_(self.fc_fij.weight, gain=gain)
......@@ -183,13 +200,15 @@ class EGATConv(nn.Module):
with graph.local_scope():
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue."
)
# calc edge attention
# same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats
......@@ -203,27 +222,31 @@ class EGATConv(nn.Module):
f_nj = self.fc_nj(nfeats_dst)
f_fij = self.fc_fij(efeats)
graph.srcdata.update({'f_ni': f_ni})
graph.dstdata.update({'f_nj': f_nj})
graph.srcdata.update({"f_ni": f_ni})
graph.dstdata.update({"f_nj": f_nj})
# add ni, nj factors
graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
graph.apply_edges(fn.u_add_v("f_ni", "f_nj", "f_tmp"))
# add fij to node factor
f_out = graph.edata.pop('f_tmp') + f_fij
f_out = graph.edata.pop("f_tmp") + f_fij
if self.bias is not None:
f_out = f_out + self.bias
f_out = nn.functional.leaky_relu(f_out)
f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
# compute attention factor
e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
graph.edata['a'] = edge_softmax(graph, e)
graph.srcdata['h_out'] = self.fc_node_src(nfeats_src).view(-1, self._num_heads,
self._out_node_feats)
graph.edata["a"] = edge_softmax(graph, e)
graph.srcdata["h_out"] = self.fc_node_src(nfeats_src).view(
-1, self._num_heads, self._out_node_feats
)
# calc weighted sum
graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
fn.sum('m', 'h_out'))
graph.update_all(
fn.u_mul_e("h_out", "a", "m"), fn.sum("m", "h_out")
)
h_out = graph.dstdata['h_out'].view(-1, self._num_heads, self._out_node_feats)
h_out = graph.dstdata["h_out"].view(
-1, self._num_heads, self._out_node_feats
)
if get_attention:
return h_out, f_out, graph.edata.pop('a')
return h_out, f_out, graph.edata.pop("a")
else:
return h_out, f_out
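Note: the EGATConv edge update above combines three linear maps (fc_ni, fc_nj, fc_fij on source node, destination node and edge features) into the new edge feature, scores it against attn, and aggregates projected source features; per head, roughly:

\[
f_{ij}' = \operatorname{LeakyReLU}\!\big( A_1 h_i + A_2 h_j + A_3 f_{ij} + b \big),
\qquad
e_{ij} = a^{\top} f_{ij}' ,
\qquad
h_j' = \sum_{i \in \mathcal{N}(j)} \operatorname{softmax}_i(e_{ij})\; W_{\mathrm{src}} h_i ,
\]

with (h', f') returned and the attention weights optionally returned as a third value.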
......@@ -4,10 +4,11 @@ import torch as th
from torch import nn
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair
from ...functional import edge_softmax
from ..utils import Identity
# pylint: enable=W0235
class GATConv(nn.Module):
......@@ -130,17 +131,20 @@ class GATConv(nn.Module):
[-0.5945, -0.4801],
[ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True):
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True,
):
super(GATConv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -148,29 +152,39 @@ class GATConv(nn.Module):
self._allow_zero_in_degree = allow_zero_in_degree
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False)
self._in_src_feats, out_feats * num_heads, bias=False
)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=False)
self._in_dst_feats, out_feats * num_heads, bias=False
)
else:
self.fc = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self._in_src_feats, out_feats * num_heads, bias=False
)
self.attn_l = nn.Parameter(
th.FloatTensor(size=(1, num_heads, out_feats))
)
self.attn_r = nn.Parameter(
th.FloatTensor(size=(1, num_heads, out_feats))
)
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
if bias:
self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_feats,)))
self.bias = nn.Parameter(
th.FloatTensor(size=(num_heads * out_feats,))
)
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
if residual:
if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=False)
self._in_dst_feats, num_heads * out_feats, bias=False
)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.register_buffer("res_fc", None)
self.reset_parameters()
self.activation = activation
......@@ -186,8 +200,8 @@ class GATConv(nn.Module):
The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The attention weights are using xavier initialization method.
"""
gain = nn.init.calculate_gain('relu')
if hasattr(self, 'fc'):
gain = nn.init.calculate_gain("relu")
if hasattr(self, "fc"):
nn.init.xavier_normal_(self.fc.weight, gain=gain)
else:
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
......@@ -251,40 +265,49 @@ class GATConv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
src_prefix_shape = feat[0].shape[:-1]
dst_prefix_shape = feat[1].shape[:-1]
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'):
if not hasattr(self, "fc_src"):
feat_src = self.fc(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats)
*src_prefix_shape, self._num_heads, self._out_feats
)
feat_dst = self.fc(h_dst).view(
*dst_prefix_shape, self._num_heads, self._out_feats)
*dst_prefix_shape, self._num_heads, self._out_feats
)
else:
feat_src = self.fc_src(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats)
*src_prefix_shape, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
*dst_prefix_shape, self._num_heads, self._out_feats)
*dst_prefix_shape, self._num_heads, self._out_feats
)
else:
src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats)
*src_prefix_shape, self._num_heads, self._out_feats
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[: graph.number_of_dst_nodes()]
dst_prefix_shape = (
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
......@@ -297,31 +320,35 @@ class GATConv(nn.Module):
# which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
# Use -1 rather than self._num_heads to handle broadcasting
resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
# bias
if self.bias is not None:
rst = rst + self.bias.view(
*((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
*((1,) * len(dst_prefix_shape)),
self._num_heads,
self._out_feats
)
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
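Note: a minimal usage sketch for the PyTorch GATConv above; the graph and sizes are illustrative assumptions, and the add_self_loop call is the remedy the 0-in-degree error message itself recommends:

import dgl
import torch
from dgl.nn.pytorch import GATConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
g = dgl.add_self_loop(g)              # avoid the 0-in-degree DGLError raised above
feat = torch.ones(4, 10)
conv = GATConv(10, 2, num_heads=3)    # in_feats=10, out_feats=2, 3 heads
res = conv(g, feat)                   # shape (4, 3, 2): one 2-dim output per head
res, attn = conv(g, feat, get_attention=True)   # attn has shape (num_edges, 3, 1)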
......@@ -85,24 +85,28 @@ class GINConv(nn.Module):
[2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000,
0.0000]], grad_fn=<ReluBackward0>)
"""
def __init__(self,
apply_func=None,
aggregator_type='sum',
init_eps=0,
learn_eps=False,
activation=None):
def __init__(
self,
apply_func=None,
aggregator_type="sum",
init_eps=0,
learn_eps=False,
activation=None,
):
super(GINConv, self).__init__()
self.apply_func = apply_func
self._aggregator_type = aggregator_type
self.activation = activation
if aggregator_type not in ('sum', 'max', 'mean'):
if aggregator_type not in ("sum", "max", "mean"):
raise KeyError(
'Aggregator type {} not recognized.'.format(aggregator_type))
"Aggregator type {} not recognized.".format(aggregator_type)
)
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
self.register_buffer("eps", th.FloatTensor([init_eps]))
def forward(self, graph, feat, edge_weight=None):
r"""
......@@ -136,16 +140,16 @@ class GINConv(nn.Module):
"""
_reducer = getattr(fn, self._aggregator_type)
with graph.local_scope():
aggregate_fn = fn.copy_u('h', 'm')
aggregate_fn = fn.copy_u("h", "m")
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
graph.edata["_edge_weight"] = edge_weight
aggregate_fn = fn.u_mul_e("h", "_edge_weight", "m")
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, _reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
graph.srcdata["h"] = feat_src
graph.update_all(aggregate_fn, _reducer("m", "neigh"))
rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
if self.apply_func is not None:
rst = self.apply_func(rst)
# activation
......
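Note: the GINConv update above is the GIN rule with optional per-edge weights folded into the aggregation:

\[
h_i' = f_{\Theta}\!\Big( (1 + \varepsilon)\, h_i + \operatorname{agg}_{j \in \mathcal{N}(i)} \big( w_{ji}\, h_j \big) \Big),
\]

where agg is the sum/max/mean reducer named by aggregator_type, w_{ji} is 1 when edge_weight is None, and f_Theta is apply_func (skipped when it is None).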
......@@ -6,8 +6,8 @@ from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair
from ..utils import Identity
class GMMConv(nn.Module):
......@@ -103,45 +103,54 @@ class GMMConv(nn.Module):
[-0.1377, -0.1943],
[-0.1107, -0.1559]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
dim,
n_kernels,
aggregator_type='sum',
residual=False,
bias=True,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
dim,
n_kernels,
aggregator_type="sum",
residual=False,
bias=True,
allow_zero_in_degree=False,
):
super(GMMConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._dim = dim
self._n_kernels = n_kernels
self._allow_zero_in_degree = allow_zero_in_degree
if aggregator_type == 'sum':
if aggregator_type == "sum":
self._reducer = fn.sum
elif aggregator_type == 'mean':
elif aggregator_type == "mean":
self._reducer = fn.mean
elif aggregator_type == 'max':
elif aggregator_type == "max":
self._reducer = fn.max
else:
raise KeyError("Aggregator type {} not recognized.".format(aggregator_type))
raise KeyError(
"Aggregator type {} not recognized.".format(aggregator_type)
)
self.mu = nn.Parameter(th.Tensor(n_kernels, dim))
self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim))
self.fc = nn.Linear(self._in_src_feats, n_kernels * out_feats, bias=False)
self.fc = nn.Linear(
self._in_src_feats, n_kernels * out_feats, bias=False
)
if residual:
if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(self._in_dst_feats, out_feats, bias=False)
self.res_fc = nn.Linear(
self._in_dst_feats, out_feats, bias=False
)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.register_buffer("res_fc", None)
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
self.reset_parameters()
def reset_parameters(self):
......@@ -158,7 +167,7 @@ class GMMConv(nn.Module):
The mu weight is initialized using normal distribution and
inv_sigma is initialized with constant value 1.0.
"""
gain = init.calculate_gain('relu')
gain = init.calculate_gain("relu")
init.xavier_normal_(self.fc.weight, gain=gain)
if isinstance(self.res_fc, nn.Linear):
init.xavier_normal_(self.res_fc.weight, gain=gain)
......@@ -218,27 +227,38 @@ class GMMConv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = self.fc(feat_src).view(-1, self._n_kernels, self._out_feats)
graph.srcdata["h"] = self.fc(feat_src).view(
-1, self._n_kernels, self._out_feats
)
E = graph.number_of_edges()
# compute gaussian weight
gaussian = -0.5 * ((pseudo.view(E, 1, self._dim) -
self.mu.view(1, self._n_kernels, self._dim)) ** 2)
gaussian = gaussian * (self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2)
gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1)
graph.edata['w'] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
rst = graph.dstdata['h'].sum(1)
gaussian = -0.5 * (
(
pseudo.view(E, 1, self._dim)
- self.mu.view(1, self._n_kernels, self._dim)
)
** 2
)
gaussian = gaussian * (
self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2
)
gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1)
graph.edata["w"] = gaussian
graph.update_all(fn.u_mul_e("h", "w", "m"), self._reducer("m", "h"))
rst = graph.dstdata["h"].sum(1)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst)
......
......@@ -6,10 +6,11 @@ from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ....transforms import reverse
from ....convert import block_to_graph
from ....heterograph import DGLBlock
from ....transforms import reverse
from ....utils import expand_as_pair
class EdgeWeightNorm(nn.Module):
r"""This module normalizes positive scalar edge weights on a graph
......@@ -59,7 +60,8 @@ class EdgeWeightNorm(nn.Module):
[-1.3658, -0.8674],
[-0.8323, -0.5286]], grad_fn=<AddBackward0>)
"""
def __init__(self, norm='both', eps=0.):
def __init__(self, norm="both", eps=0.0):
super(EdgeWeightNorm, self).__init__()
self._norm = norm
self._eps = eps
......@@ -99,42 +101,57 @@ class EdgeWeightNorm(nn.Module):
if isinstance(graph, DGLBlock):
graph = block_to_graph(graph)
if len(edge_weight.shape) > 1:
raise DGLError('Currently the normalization is only defined '
'on scalar edge weight. Please customize the '
'normalization for your high-dimensional weights.')
if self._norm == 'both' and th.any(edge_weight <= 0).item():
raise DGLError('Non-positive edge weight detected with `norm="both"`. '
'This leads to square root of zero or negative values.')
raise DGLError(
"Currently the normalization is only defined "
"on scalar edge weight. Please customize the "
"normalization for your high-dimensional weights."
)
if self._norm == "both" and th.any(edge_weight <= 0).item():
raise DGLError(
'Non-positive edge weight detected with `norm="both"`. '
"This leads to square root of zero or negative values."
)
dev = graph.device
dtype = edge_weight.dtype
graph.srcdata['_src_out_w'] = th.ones(
graph.number_of_src_nodes(), dtype=dtype, device=dev)
graph.dstdata['_dst_in_w'] = th.ones(
graph.number_of_dst_nodes(), dtype=dtype, device=dev)
graph.edata['_edge_w'] = edge_weight
if self._norm == 'both':
graph.srcdata["_src_out_w"] = th.ones(
graph.number_of_src_nodes(), dtype=dtype, device=dev
)
graph.dstdata["_dst_in_w"] = th.ones(
graph.number_of_dst_nodes(), dtype=dtype, device=dev
)
graph.edata["_edge_w"] = edge_weight
if self._norm == "both":
reversed_g = reverse(graph)
reversed_g.edata['_edge_w'] = edge_weight
reversed_g.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'out_weight'))
degs = reversed_g.dstdata['out_weight'] + self._eps
reversed_g.edata["_edge_w"] = edge_weight
reversed_g.update_all(
fn.copy_e("_edge_w", "m"), fn.sum("m", "out_weight")
)
degs = reversed_g.dstdata["out_weight"] + self._eps
norm = th.pow(degs, -0.5)
graph.srcdata['_src_out_w'] = norm
if self._norm != 'none':
graph.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'in_weight'))
degs = graph.dstdata['in_weight'] + self._eps
if self._norm == 'both':
graph.srcdata["_src_out_w"] = norm
if self._norm != "none":
graph.update_all(
fn.copy_e("_edge_w", "m"), fn.sum("m", "in_weight")
)
degs = graph.dstdata["in_weight"] + self._eps
if self._norm == "both":
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
graph.dstdata['_dst_in_w'] = norm
graph.dstdata["_dst_in_w"] = norm
graph.apply_edges(
lambda e: {
"_norm_edge_weights": e.src["_src_out_w"]
* e.dst["_dst_in_w"]
* e.data["_edge_w"]
}
)
return graph.edata["_norm_edge_weights"]
graph.apply_edges(lambda e: {'_norm_edge_weights': e.src['_src_out_w'] * \
e.dst['_dst_in_w'] * \
e.data['_edge_w']})
return graph.edata['_norm_edge_weights']
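For context, a hedged usage sketch of the normalization path above; the toy graph and weights are made up. With norm="both", each edge weight ends up divided by the square roots of the weighted out-degree of its source and the weighted in-degree of its destination:

import dgl
import torch as th
from dgl.nn import EdgeWeightNorm

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
w = th.rand(g.num_edges()).clamp(min=0.1)   # positive scalar weights, as "both" requires
norm = EdgeWeightNorm(norm="both", eps=0.0)
norm_w = norm(g, w)                         # shape (E,); typically fed to GraphConv(..., norm="none")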
# pylint: disable=W0235
class GraphConv(nn.Module):
......@@ -266,18 +283,23 @@ class GraphConv(nn.Module):
[-0.5287, 0.8235],
[-0.2994, 0.6106]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
norm='both',
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
norm="both",
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False,
):
super(GraphConv, self).__init__()
if norm not in ('none', 'both', 'right', 'left'):
raise DGLError('Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm))
if norm not in ("none", "both", "right", "left"):
raise DGLError(
'Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm)
)
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
......@@ -286,12 +308,12 @@ class GraphConv(nn.Module):
if weight:
self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
else:
self.register_parameter('weight', None)
self.register_parameter("weight", None)
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
self.reset_parameters()
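Related to the weight=True/False switch above: when the module is built without its own weight, an external weight tensor can be passed at call time instead. A hedged sketch with placeholder sizes:

import dgl
import torch as th
from dgl.nn import GraphConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
conv = GraphConv(5, 2, weight=False)                          # self.weight registered as None
external_w = th.rand(5, 2)
out = conv(g, th.rand(g.num_nodes(), 5), weight=external_w)   # (3, 2)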
......@@ -383,26 +405,28 @@ class GraphConv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
aggregate_fn = fn.copy_u('h', 'm')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
aggregate_fn = fn.copy_u("h", "m")
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
graph.edata["_edge_weight"] = edge_weight
aggregate_fn = fn.u_mul_e("h", "_edge_weight", "m")
# (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm in ['left', 'both']:
if self._norm in ["left", "both"]:
degs = graph.out_degrees().to(feat_src).clamp(min=1)
if self._norm == 'both':
if self._norm == "both":
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
......@@ -412,9 +436,11 @@ class GraphConv(nn.Module):
if weight is not None:
if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
raise DGLError(
"External weight is provided while at the same time the"
" module has defined its own weight parameter. Please"
" create the module with flag weight=False."
)
else:
weight = self.weight
......@@ -422,20 +448,20 @@ class GraphConv(nn.Module):
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat_src = th.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
graph.srcdata["h"] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg="m", out="h"))
rst = graph.dstdata["h"]
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
graph.srcdata["h"] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg="m", out="h"))
rst = graph.dstdata["h"]
if weight is not None:
rst = th.matmul(rst, weight)
if self._norm in ['right', 'both']:
if self._norm in ["right", "both"]:
degs = graph.in_degrees().to(feat_dst).clamp(min=1)
if self._norm == 'both':
if self._norm == "both":
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
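The left/right/both branches above are the usual degree normalizations; a dense-matrix sketch of the symmetric ("both") case, purely illustrative:

import torch as th

A = th.tensor([[0., 1., 1.],
               [1., 0., 1.],
               [1., 1., 0.]])                                   # toy adjacency
deg = A.sum(dim=1).clamp(min=1)                                 # degrees, clamped like the diff
d_inv_sqrt = deg.pow(-0.5)
A_hat = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)   # D^{-1/2} A D^{-1/2}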
......@@ -455,8 +481,8 @@ class GraphConv(nn.Module):
"""Set the extra representation of the module,
which will come into effect when printing the model.
"""
summary = 'in={_in_feats}, out={_out_feats}'
summary += ', normalization={_norm}'
if '_activation' in self.__dict__:
summary += ', activation={_activation}'
summary = "in={_in_feats}, out={_out_feats}"
summary += ", normalization={_norm}"
if "_activation" in self.__dict__:
summary += ", activation={_activation}"
return summary.format(**self.__dict__)
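extra_repr above only affects printing; a short example of what it produces (feature sizes are placeholders and the printed line is approximate):

from dgl.nn import GraphConv

conv = GraphConv(5, 2, norm="both")
print(conv)   # roughly: GraphConv(in=5, out=2, normalization=both, activation=None)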
......@@ -144,9 +144,9 @@ class HGTConv(nn.Module):
self.presorted = presorted
if g.is_block:
x_src = x
x_dst = x[:g.num_dst_nodes()]
x_dst = x[: g.num_dst_nodes()]
srcntype = ntype
dstntype = ntype[:g.num_dst_nodes()]
dstntype = ntype[: g.num_dst_nodes()]
else:
x_src = x
x_dst = x
......@@ -155,13 +155,13 @@ class HGTConv(nn.Module):
with g.local_scope():
k = self.linear_k(x_src, srcntype, presorted).view(
-1, self.num_heads, self.head_size
)
)
q = self.linear_q(x_dst, dstntype, presorted).view(
-1, self.num_heads, self.head_size
)
)
v = self.linear_v(x_src, srcntype, presorted).view(
-1, self.num_heads, self.head_size
)
)
g.srcdata["k"] = k
g.dstdata["q"] = q
g.srcdata["v"] = v
......
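The reshapes in the HGTConv hunk above split the projected features into per-head slices; a shape-only sketch with assumed sizes:

import math
import torch as th

num_src, num_dst, num_heads, head_size = 6, 4, 2, 8       # assumed sizes
k = th.rand(num_src, num_heads * head_size).view(-1, num_heads, head_size)   # (6, 2, 8)
q = th.rand(num_dst, num_heads * head_size).view(-1, num_heads, head_size)   # (4, 2, 8)
# a per-edge attention logit would then be a dot product over head_size, e.g. for one (u, v) pair:
logit = (k[0] * q[0]).sum(-1) / math.sqrt(head_size)      # (num_heads,)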
......@@ -5,8 +5,8 @@ from torch import nn
from torch.nn import init
from .... import function as fn
from ..utils import Identity
from ....utils import expand_as_pair
from ..utils import Identity
class NNConv(nn.Module):
......@@ -84,37 +84,44 @@ class NNConv(nn.Module):
[ 0.1261, -0.0155],
[-0.6568, 0.5042]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
edge_func,
aggregator_type='mean',
residual=False,
bias=True):
def __init__(
self,
in_feats,
out_feats,
edge_func,
aggregator_type="mean",
residual=False,
bias=True,
):
super(NNConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self.edge_func = edge_func
if aggregator_type == 'sum':
if aggregator_type == "sum":
self.reducer = fn.sum
elif aggregator_type == 'mean':
elif aggregator_type == "mean":
self.reducer = fn.mean
elif aggregator_type == 'max':
elif aggregator_type == "max":
self.reducer = fn.max
else:
raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type))
raise KeyError(
"Aggregator type {} not recognized: ".format(aggregator_type)
)
self._aggre_type = aggregator_type
if residual:
if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(self._in_dst_feats, out_feats, bias=False)
self.res_fc = nn.Linear(
self._in_dst_feats, out_feats, bias=False
)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.register_buffer("res_fc", None)
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
self.reset_parameters()
def reset_parameters(self):
......@@ -129,7 +136,7 @@ class NNConv(nn.Module):
The model parameters are initialized using Glorot uniform initialization
and the bias is initialized to be zero.
"""
gain = init.calculate_gain('relu')
gain = init.calculate_gain("relu")
if self.bias is not None:
nn.init.zeros_(self.bias)
if isinstance(self.res_fc, nn.Linear):
......@@ -161,12 +168,16 @@ class NNConv(nn.Module):
feat_src, feat_dst = expand_as_pair(feat, graph)
# (n, d_in, 1)
graph.srcdata['h'] = feat_src.unsqueeze(-1)
graph.srcdata["h"] = feat_src.unsqueeze(-1)
# (n, d_in, d_out)
graph.edata['w'] = self.edge_func(efeat).view(-1, self._in_src_feats, self._out_feats)
graph.edata["w"] = self.edge_func(efeat).view(
-1, self._in_src_feats, self._out_feats
)
# (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh'))
rst = graph.dstdata['neigh'].sum(dim=1) # (n, d_out)
graph.update_all(
fn.u_mul_e("h", "w", "m"), self.reducer("m", "neigh")
)
rst = graph.dstdata["neigh"].sum(dim=1) # (n, d_out)
# residual connection
if self.res_fc is not None:
rst = rst + self.res_fc(feat_dst)
......
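To close out the NNConv hunks, a hedged end-to-end usage sketch with a made-up MLP as edge_func; shapes follow the (n, d_in, d_out) comments above:

import dgl
import torch as th
from torch import nn
from dgl.nn import NNConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
in_feats, out_feats, edge_feats = 5, 2, 3                 # placeholder sizes
edge_func = nn.Sequential(nn.Linear(edge_feats, 16), nn.ReLU(),
                          nn.Linear(16, in_feats * out_feats))
conv = NNConv(in_feats, out_feats, edge_func, aggregator_type="mean")
out = conv(g, th.rand(g.num_nodes(), in_feats), th.rand(g.num_edges(), edge_feats))
print(out.shape)                                          # torch.Size([4, 2])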