Unverified commit 76bb5404, authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
python/dgl/nn/mxnet/conv/agnnconv.py

@@ -4,10 +4,10 @@ import mxnet as mx
 from mxnet.gluon import nn

 from .... import function as fn
-from ...functional import edge_softmax
-from ..utils import normalize
 from ....base import DGLError
 from ....utils import expand_as_pair
+from ...functional import edge_softmax
+from ..utils import normalize


 class AGNNConv(nn.Block):
@@ -75,17 +75,19 @@ class AGNNConv(nn.Block):
         [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
        <NDArray 6x10 @cpu(0)>
     """
-    def __init__(self,
-                 init_beta=1.,
-                 learn_beta=True,
-                 allow_zero_in_degree=False):
+
+    def __init__(
+        self, init_beta=1.0, learn_beta=True, allow_zero_in_degree=False
+    ):
         super(AGNNConv, self).__init__()
         self._allow_zero_in_degree = allow_zero_in_degree
         with self.name_scope():
-            self.beta = self.params.get('beta',
-                                        shape=(1,),
-                                        grad_req='write' if learn_beta else 'null',
-                                        init=mx.init.Constant(init_beta))
+            self.beta = self.params.get(
+                "beta",
+                shape=(1,),
+                grad_req="write" if learn_beta else "null",
+                init=mx.init.Constant(init_beta),
+            )

     def set_allow_zero_in_degree(self, set_value):
         r"""
@@ -135,25 +137,27 @@ class AGNNConv(nn.Block):
         with graph.local_scope():
             if not self._allow_zero_in_degree:
                 if graph.in_degrees().min() == 0:
-                    raise DGLError('There are 0-in-degree nodes in the graph, '
-                                   'output for those nodes will be invalid. '
-                                   'This is harmful for some applications, '
-                                   'causing silent performance regression. '
-                                   'Adding self-loop on the input graph by '
-                                   'calling `g = dgl.add_self_loop(g)` will resolve '
-                                   'the issue. Setting ``allow_zero_in_degree`` '
-                                   'to be `True` when constructing this module will '
-                                   'suppress the check and let the code run.')
+                    raise DGLError(
+                        "There are 0-in-degree nodes in the graph, "
+                        "output for those nodes will be invalid. "
+                        "This is harmful for some applications, "
+                        "causing silent performance regression. "
+                        "Adding self-loop on the input graph by "
+                        "calling `g = dgl.add_self_loop(g)` will resolve "
+                        "the issue. Setting ``allow_zero_in_degree`` "
+                        "to be `True` when constructing this module will "
+                        "suppress the check and let the code run."
+                    )

             feat_src, feat_dst = expand_as_pair(feat, graph)
-            graph.srcdata['h'] = feat_src
-            graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
+            graph.srcdata["h"] = feat_src
+            graph.srcdata["norm_h"] = normalize(feat_src, p=2, axis=-1)
             if isinstance(feat, tuple) or graph.is_block:
-                graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1)
+                graph.dstdata["norm_h"] = normalize(feat_dst, p=2, axis=-1)
             # compute cosine distance
-            graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
-            cos = graph.edata.pop('cos')
+            graph.apply_edges(fn.u_dot_v("norm_h", "norm_h", "cos"))
+            cos = graph.edata.pop("cos")
             e = self.beta.data(feat_src.context) * cos
-            graph.edata['p'] = edge_softmax(graph, e)
-            graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
-            return graph.dstdata.pop('h')
+            graph.edata["p"] = edge_softmax(graph, e)
+            graph.update_all(fn.u_mul_e("h", "p", "m"), fn.sum("m", "h"))
+            return graph.dstdata.pop("h")
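Aside, not part of the commit: a minimal usage sketch of the module above, assuming an MXNet build of DGL is installed; the toy graph and feature sizes are illustrative only.

import dgl
import mxnet as mx
from dgl.nn.mxnet.conv import AGNNConv

# Self-loops avoid the 0-in-degree error raised in forward() above.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feat = mx.nd.random.uniform(shape=(3, 10))
conv = AGNNConv(init_beta=1.0, learn_beta=True)
conv.initialize()
out = conv(g, feat)  # edge attention = softmax(beta * cosine(src, dst)); shape (3, 10)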
python/dgl/nn/mxnet/conv/appnpconv.py

@@ -6,6 +6,7 @@ from mxnet.gluon import nn
 from .... import function as fn

+
 class APPNPConv(nn.Block):
     r"""Approximate Personalized Propagation of Neural Predictions layer from `Predict then
     Propagate: Graph Neural Networks meet Personalized PageRank
@@ -56,10 +57,8 @@ class APPNPConv(nn.Block):
         0.5 0.5 0.5 0.5 ]]
        <NDArray 6x10 @cpu(0)>
     """
-    def __init__(self,
-                 k,
-                 alpha,
-                 edge_drop=0.):
+
+    def __init__(self, k, alpha, edge_drop=0.0):
         super(APPNPConv, self).__init__()
         self._k = k
         self._alpha = alpha
@@ -88,20 +87,26 @@ class APPNPConv(nn.Block):
             should be the same as input shape.
         """
         with graph.local_scope():
-            norm = mx.nd.power(mx.nd.clip(
-                graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5)
+            norm = mx.nd.power(
+                mx.nd.clip(
+                    graph.in_degrees().astype(feat.dtype),
+                    a_min=1,
+                    a_max=float("inf"),
+                ),
+                -0.5,
+            )
             shp = norm.shape + (1,) * (feat.ndim - 1)
             norm = norm.reshape(shp).as_in_context(feat.context)
             feat_0 = feat
             for _ in range(self._k):
                 # normalization by src node
                 feat = feat * norm
-                graph.ndata['h'] = feat
-                graph.edata['w'] = self.edge_drop(
-                    nd.ones((graph.number_of_edges(), 1), ctx=feat.context))
-                graph.update_all(fn.u_mul_e('h', 'w', 'm'),
-                                 fn.sum('m', 'h'))
-                feat = graph.ndata.pop('h')
+                graph.ndata["h"] = feat
+                graph.edata["w"] = self.edge_drop(
+                    nd.ones((graph.number_of_edges(), 1), ctx=feat.context)
+                )
+                graph.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"))
+                feat = graph.ndata.pop("h")
                 # normalization by dst node
                 feat = feat * norm
                 feat = (1 - self._alpha) * feat + self._alpha * feat_0
...
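Aside, not part of the commit: the loop above implements the APPNP power iteration Z <- (1 - alpha) * A_hat @ Z + alpha * H, with A_hat the symmetrically normalized adjacency. A dense NumPy sketch of that recurrence (ignoring edge_drop); all names are hypothetical.

import numpy as np

def appnp_dense(A, H, k, alpha):
    """Z_0 = H; Z_{t+1} = (1 - alpha) * A_hat @ Z_t + alpha * H."""
    A = np.asarray(A, dtype=float)
    deg = np.clip(A.sum(axis=1), 1, None)        # clamp degrees, as in the code above
    d_inv_sqrt = deg ** -0.5
    A_hat = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    Z = H
    for _ in range(k):
        Z = (1 - alpha) * (A_hat @ Z) + alpha * H
    return Z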
"""MXNet Module for Chebyshev Spectral Graph Convolution layer""" """MXNet Module for Chebyshev Spectral Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import broadcast_nodes
from .... import function as fn
from ....base import dgl_warning from ....base import dgl_warning
from .... import broadcast_nodes, function as fn
class ChebConv(nn.Block): class ChebConv(nn.Block):
...@@ -60,11 +62,8 @@ class ChebConv(nn.Block): ...@@ -60,11 +62,8 @@ class ChebConv(nn.Block):
[ 1.7954229 0.00196505]] [ 1.7954229 0.00196505]]
<NDArray 6x2 @cpu(0)> <NDArray 6x2 @cpu(0)>
""" """
def __init__(self,
in_feats, def __init__(self, in_feats, out_feats, k, bias=True):
out_feats,
k,
bias=True):
super(ChebConv, self).__init__() super(ChebConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
...@@ -73,13 +72,19 @@ class ChebConv(nn.Block): ...@@ -73,13 +72,19 @@ class ChebConv(nn.Block):
self.fc = nn.Sequential() self.fc = nn.Sequential()
for _ in range(k): for _ in range(k):
self.fc.add( self.fc.add(
nn.Dense(out_feats, use_bias=False, nn.Dense(
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), out_feats,
in_units=in_feats) use_bias=False,
weight_initializer=mx.init.Xavier(
magnitude=math.sqrt(2.0)
),
in_units=in_feats,
)
) )
if bias: if bias:
self.bias = self.params.get('bias', shape=(out_feats,), self.bias = self.params.get(
init=mx.init.Zero()) "bias", shape=(out_feats,), init=mx.init.Zero()
)
else: else:
self.bias = None self.bias = None
...@@ -112,14 +117,17 @@ class ChebConv(nn.Block): ...@@ -112,14 +117,17 @@ class ChebConv(nn.Block):
is size of output feature. is size of output feature.
""" """
with graph.local_scope(): with graph.local_scope():
degs = graph.in_degrees().astype('float32') degs = graph.in_degrees().astype("float32")
norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5) norm = mx.nd.power(
mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5
)
norm = norm.expand_dims(-1).as_in_context(feat.context) norm = norm.expand_dims(-1).as_in_context(feat.context)
if lambda_max is None: if lambda_max is None:
dgl_warning( dgl_warning(
"lambda_max is not provided, using default value of 2. " "lambda_max is not provided, using default value of 2. "
"Please use dgl.laplacian_lambda_max to compute the eigenvalues.") "Please use dgl.laplacian_lambda_max to compute the eigenvalues."
)
lambda_max = [2] * graph.batch_size lambda_max = [2] * graph.batch_size
if isinstance(lambda_max, list): if isinstance(lambda_max, list):
...@@ -133,23 +141,25 @@ class ChebConv(nn.Block): ...@@ -133,23 +141,25 @@ class ChebConv(nn.Block):
rst = self.fc[0](Tx_0) rst = self.fc[0](Tx_0)
# T1(X) # T1(X)
if self._k > 1: if self._k > 1:
graph.ndata['h'] = Tx_0 * norm graph.ndata["h"] = Tx_0 * norm
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
h = graph.ndata.pop('h') * norm h = graph.ndata.pop("h") * norm
# Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
# = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I # = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1) Tx_1 = -2.0 * h / lambda_max + Tx_0 * (2.0 / lambda_max - 1)
rst = rst + self.fc[1](Tx_1) rst = rst + self.fc[1](Tx_1)
# Ti(x), i = 2...k # Ti(x), i = 2...k
for i in range(2, self._k): for i in range(2, self._k):
graph.ndata['h'] = Tx_1 * norm graph.ndata["h"] = Tx_1 * norm
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
h = graph.ndata.pop('h') * norm h = graph.ndata.pop("h") * norm
# Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2) # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
# = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) + # = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
# (4 / lambda_max - 2) Tx_(k-1) - # (4 / lambda_max - 2) Tx_(k-1) -
# Tx_(k-2) # Tx_(k-2)
Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max - 2) - Tx_0 Tx_2 = (
-4.0 * h / lambda_max + Tx_1 * (4.0 / lambda_max - 2) - Tx_0
)
rst = rst + self.fc[i](Tx_2) rst = rst + self.fc[i](Tx_2)
Tx_1, Tx_0 = Tx_2, Tx_1 Tx_1, Tx_0 = Tx_2, Tx_1
# add bias # add bias
......
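Aside, not part of the commit: the Tx_0/Tx_1/Tx_2 bookkeeping above is the Chebyshev recurrence on the rescaled Laplacian L_hat = 2L/lambda_max - I. A dense NumPy sketch of the basis terms; names are hypothetical.

import numpy as np

def cheb_terms(L_hat, X, k):
    """T_0 = X, T_1 = L_hat @ X, T_i = 2 * L_hat @ T_{i-1} - T_{i-2}."""
    terms = [X]
    if k > 1:
        terms.append(L_hat @ X)
    for _ in range(2, k):
        terms.append(2 * (L_hat @ terms[-1]) - terms[-2])
    # ChebConv feeds each T_i through its own Dense layer and sums the results.
    return terms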
"""MXNet Module for DenseChebConv""" """MXNet Module for DenseChebConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
...@@ -29,11 +30,8 @@ class DenseChebConv(nn.Block): ...@@ -29,11 +30,8 @@ class DenseChebConv(nn.Block):
-------- --------
`ChebConv <https://docs.dgl.ai/api/python/nn.pytorch.html#chebconv>`__ `ChebConv <https://docs.dgl.ai/api/python/nn.pytorch.html#chebconv>`__
""" """
def __init__(self,
in_feats, def __init__(self, in_feats, out_feats, k, bias=True):
out_feats,
k,
bias=True):
super(DenseChebConv, self).__init__() super(DenseChebConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
...@@ -42,12 +40,19 @@ class DenseChebConv(nn.Block): ...@@ -42,12 +40,19 @@ class DenseChebConv(nn.Block):
self.fc = nn.Sequential() self.fc = nn.Sequential()
for _ in range(k): for _ in range(k):
self.fc.add( self.fc.add(
nn.Dense(out_feats, in_units=in_feats, use_bias=False, nn.Dense(
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) out_feats,
in_units=in_feats,
use_bias=False,
weight_initializer=mx.init.Xavier(
magnitude=math.sqrt(2.0)
),
)
) )
if bias: if bias:
self.bias = self.params.get('bias', shape=(out_feats,), self.bias = self.params.get(
init=mx.init.Zero()) "bias", shape=(out_feats,), init=mx.init.Zero()
)
else: else:
self.bias = None self.bias = None
...@@ -80,7 +85,7 @@ class DenseChebConv(nn.Block): ...@@ -80,7 +85,7 @@ class DenseChebConv(nn.Block):
A = adj.astype(feat.dtype).as_in_context(feat.context) A = adj.astype(feat.dtype).as_in_context(feat.context)
num_nodes = A.shape[0] num_nodes = A.shape[0]
in_degree = 1. / nd.clip(A.sum(axis=1), 1, float('inf')).sqrt() in_degree = 1.0 / nd.clip(A.sum(axis=1), 1, float("inf")).sqrt()
D_invsqrt = nd.diag(in_degree) D_invsqrt = nd.diag(in_degree)
I = nd.eye(num_nodes, ctx=A.context) I = nd.eye(num_nodes, ctx=A.context)
L = I - nd.dot(D_invsqrt, nd.dot(A, D_invsqrt)) L = I - nd.dot(D_invsqrt, nd.dot(A, D_invsqrt))
......
"""MXNet Module for DenseGraphConv""" """MXNet Module for DenseGraphConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
...@@ -40,22 +41,24 @@ class DenseGraphConv(nn.Block): ...@@ -40,22 +41,24 @@ class DenseGraphConv(nn.Block):
-------- --------
`GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__ `GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__
""" """
def __init__(self,
in_feats, def __init__(
out_feats, self, in_feats, out_feats, norm="both", bias=True, activation=None
norm='both', ):
bias=True,
activation=None):
super(DenseGraphConv, self).__init__() super(DenseGraphConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
self._norm = norm self._norm = norm
with self.name_scope(): with self.name_scope():
self.weight = self.params.get('weight', shape=(in_feats, out_feats), self.weight = self.params.get(
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) "weight",
shape=(in_feats, out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
if bias: if bias:
self.bias = self.params.get('bias', shape=(out_feats,), self.bias = self.params.get(
init=mx.init.Zero()) "bias", shape=(out_feats,), init=mx.init.Zero()
)
else: else:
self.bias = None self.bias = None
self._activation = activation self._activation = activation
...@@ -86,11 +89,11 @@ class DenseGraphConv(nn.Block): ...@@ -86,11 +89,11 @@ class DenseGraphConv(nn.Block):
is size of output feature. is size of output feature.
""" """
adj = adj.astype(feat.dtype).as_in_context(feat.context) adj = adj.astype(feat.dtype).as_in_context(feat.context)
src_degrees = nd.clip(adj.sum(axis=0), a_min=1, a_max=float('inf')) src_degrees = nd.clip(adj.sum(axis=0), a_min=1, a_max=float("inf"))
dst_degrees = nd.clip(adj.sum(axis=1), a_min=1, a_max=float('inf')) dst_degrees = nd.clip(adj.sum(axis=1), a_min=1, a_max=float("inf"))
feat_src = feat feat_src = feat
if self._norm == 'both': if self._norm == "both":
norm_src = nd.power(src_degrees, -0.5) norm_src = nd.power(src_degrees, -0.5)
shp_src = norm_src.shape + (1,) * (feat.ndim - 1) shp_src = norm_src.shape + (1,) * (feat.ndim - 1)
norm_src = norm_src.reshape(shp_src).as_in_context(feat.context) norm_src = norm_src.reshape(shp_src).as_in_context(feat.context)
...@@ -105,10 +108,10 @@ class DenseGraphConv(nn.Block): ...@@ -105,10 +108,10 @@ class DenseGraphConv(nn.Block):
rst = nd.dot(adj, feat_src) rst = nd.dot(adj, feat_src)
rst = nd.dot(rst, self.weight.data(feat_src.context)) rst = nd.dot(rst, self.weight.data(feat_src.context))
if self._norm != 'none': if self._norm != "none":
if self._norm == 'both': if self._norm == "both":
norm_dst = nd.power(dst_degrees, -0.5) norm_dst = nd.power(dst_degrees, -0.5)
else: # right else: # right
norm_dst = 1.0 / dst_degrees norm_dst = 1.0 / dst_degrees
shp_dst = norm_dst.shape + (1,) * (feat.ndim - 1) shp_dst = norm_dst.shape + (1,) * (feat.ndim - 1)
norm_dst = norm_dst.reshape(shp_dst).as_in_context(feat.context) norm_dst = norm_dst.reshape(shp_dst).as_in_context(feat.context)
......
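Aside, not part of the commit: a dense NumPy sketch of the forward pass above, i.e. rst = D_dst^{-1/2} A D_src^{-1/2} X W when norm="both"; names are hypothetical.

import numpy as np

def dense_gcn_forward(adj, feat, weight, norm="both"):
    adj = np.asarray(adj, dtype=float)
    # Clamp degrees to 1 so isolated nodes do not divide by zero.
    src_deg = np.clip(adj.sum(axis=0), 1, None)
    dst_deg = np.clip(adj.sum(axis=1), 1, None)
    if norm == "both":
        feat = feat * src_deg[:, None] ** -0.5
    rst = adj @ feat @ weight
    if norm == "both":
        rst = rst * dst_deg[:, None] ** -0.5
    elif norm == "right":
        rst = rst / dst_deg[:, None]
    return rst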
"""MXNet Module for DenseGraphSAGE""" """MXNet Module for DenseGraphSAGE"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import mxnet as mx import mxnet as mx
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from ....utils import check_eq_shape from ....utils import check_eq_shape
...@@ -35,13 +37,16 @@ class DenseSAGEConv(nn.Block): ...@@ -35,13 +37,16 @@ class DenseSAGEConv(nn.Block):
-------- --------
`SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__ `SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__
""" """
def __init__(self,
in_feats, def __init__(
out_feats, self,
feat_drop=0., in_feats,
bias=True, out_feats,
norm=None, feat_drop=0.0,
activation=None): bias=True,
norm=None,
activation=None,
):
super(DenseSAGEConv, self).__init__() super(DenseSAGEConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
...@@ -49,8 +54,12 @@ class DenseSAGEConv(nn.Block): ...@@ -49,8 +54,12 @@ class DenseSAGEConv(nn.Block):
with self.name_scope(): with self.name_scope():
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation self.activation = activation
self.fc = nn.Dense(out_feats, in_units=in_feats, use_bias=bias, self.fc = nn.Dense(
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) out_feats,
in_units=in_feats,
use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
def forward(self, adj, feat): def forward(self, adj, feat):
r""" r"""
......
python/dgl/nn/mxnet/conv/gatedgraphconv.py

@@ -6,6 +6,7 @@ from mxnet.gluon import nn
 from .... import function as fn

+
 class GatedGraphConv(nn.Block):
     r"""Gated Graph Convolution layer from `Gated Graph Sequence
     Neural Networks <https://arxiv.org/pdf/1511.05493.pdf>`__
@@ -59,26 +60,24 @@ class GatedGraphConv(nn.Block):
         0.23958017 0.23430146 0.26431587 0.27001363]]
        <NDArray 6x10 @cpu(0)>
     """
-    def __init__(self,
-                 in_feats,
-                 out_feats,
-                 n_steps,
-                 n_etypes,
-                 bias=True):
+
+    def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):
         super(GatedGraphConv, self).__init__()
         self._in_feats = in_feats
         self._out_feats = out_feats
         self._n_steps = n_steps
         self._n_etypes = n_etypes
         if not bias:
-            raise KeyError('MXNet do not support disabling bias in GRUCell.')
+            raise KeyError("MXNet do not support disabling bias in GRUCell.")
         with self.name_scope():
             self.linears = nn.Sequential()
             for _ in range(n_etypes):
                 self.linears.add(
-                    nn.Dense(out_feats,
-                             weight_initializer=mx.init.Xavier(),
-                             in_units=out_feats)
+                    nn.Dense(
+                        out_feats,
+                        weight_initializer=mx.init.Xavier(),
+                        in_units=out_feats,
+                    )
                 )
             self.gru = gluon.rnn.GRUCell(out_feats, input_size=out_feats)
@@ -104,25 +103,33 @@ class GatedGraphConv(nn.Block):
             is the output feature size.
         """
         with graph.local_scope():
-            assert graph.is_homogeneous, \
-                "not a homogeneous graph; convert it with to_homogeneous " \
-                "and pass in the edge type as argument"
-            zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]),
-                                ctx=feat.context)
+            assert graph.is_homogeneous, (
+                "not a homogeneous graph; convert it with to_homogeneous "
+                "and pass in the edge type as argument"
+            )
+            zero_pad = nd.zeros(
+                (feat.shape[0], self._out_feats - feat.shape[1]),
+                ctx=feat.context,
+            )
             feat = nd.concat(feat, zero_pad, dim=-1)

             for _ in range(self._n_steps):
-                graph.ndata['h'] = feat
+                graph.ndata["h"] = feat
                 for i in range(self._n_etypes):
                     eids = (etypes.asnumpy() == i).nonzero()[0]
-                    eids = nd.from_numpy(eids, zero_copy=True).as_in_context(
-                        feat.context).astype(graph.idtype)
+                    eids = (
+                        nd.from_numpy(eids, zero_copy=True)
+                        .as_in_context(feat.context)
+                        .astype(graph.idtype)
+                    )
                     if len(eids) > 0:
                         graph.apply_edges(
-                            lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
-                            eids
+                            lambda edges: {
+                                "W_e*h": self.linears[i](edges.src["h"])
+                            },
+                            eids,
                         )
-                graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
-                a = graph.ndata.pop('a')
+                graph.update_all(fn.copy_e("W_e*h", "m"), fn.sum("m", "a"))
+                a = graph.ndata.pop("a")
                 feat = self.gru(a, [feat])[0]
             return feat
python/dgl/nn/mxnet/conv/ginconv.py

@@ -58,27 +58,30 @@ class GINConv(nn.Block):
         -0.02858362 -0.10365082  0.07060662  0.23041813]]
        <NDArray 6x10 @cpu(0)>
     """
-    def __init__(self,
-                 apply_func,
-                 aggregator_type,
-                 init_eps=0,
-                 learn_eps=False):
+
+    def __init__(
+        self, apply_func, aggregator_type, init_eps=0, learn_eps=False
+    ):
         super(GINConv, self).__init__()
-        if aggregator_type == 'sum':
+        if aggregator_type == "sum":
             self._reducer = fn.sum
-        elif aggregator_type == 'max':
+        elif aggregator_type == "max":
             self._reducer = fn.max
-        elif aggregator_type == 'mean':
+        elif aggregator_type == "mean":
             self._reducer = fn.mean
         else:
-            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
+            raise KeyError(
+                "Aggregator type {} not recognized.".format(aggregator_type)
+            )
         with self.name_scope():
             self.apply_func = apply_func
-            self.eps = self.params.get('eps',
-                                       shape=(1,),
-                                       grad_req='write' if learn_eps else 'null',
-                                       init=mx.init.Constant(init_eps))
+            self.eps = self.params.get(
+                "eps",
+                shape=(1,),
+                grad_req="write" if learn_eps else "null",
+                init=mx.init.Constant(init_eps),
+            )

     def forward(self, graph, feat):
         r"""
@@ -109,9 +112,11 @@ class GINConv(nn.Block):
         """
         with graph.local_scope():
             feat_src, feat_dst = expand_as_pair(feat, graph)
-            graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
-            rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh']
+            graph.srcdata["h"] = feat_src
+            graph.update_all(fn.copy_u("h", "m"), self._reducer("m", "neigh"))
+            rst = (
+                1 + self.eps.data(feat_dst.context)
+            ) * feat_dst + graph.dstdata["neigh"]
             if self.apply_func is not None:
                 rst = self.apply_func(rst)
             return rst
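Aside, not part of the commit: with sum aggregation, the update above is h'_v = apply_func((1 + eps) * h_v + sum over neighbors u of h_u). A NumPy sketch over a dense adjacency; names are hypothetical.

import numpy as np

def gin_update(adj, feat, eps, apply_func=None):
    adj = np.asarray(adj, dtype=float)
    neigh = adj @ feat                # sum aggregation over in-neighbors
    rst = (1 + eps) * feat + neigh    # (1 + eps) * h_v + sum_u h_u
    return apply_func(rst) if apply_func is not None else rst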
"""MXNet module for RelGraphConv""" """MXNet module for RelGraphConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import numpy as np
import mxnet as mx import mxnet as mx
import numpy as np
from mxnet import gluon, nd from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from .. import utils from .. import utils
...@@ -98,57 +99,78 @@ class RelGraphConv(gluon.Block): ...@@ -98,57 +99,78 @@ class RelGraphConv(gluon.Block):
[ 0.056508 -0.00307822]] [ 0.056508 -0.00307822]]
<NDArray 6x2 @cpu(0)> <NDArray 6x2 @cpu(0)>
""" """
def __init__(self,
in_feat, def __init__(
out_feat, self,
num_rels, in_feat,
regularizer="basis", out_feat,
num_bases=None, num_rels,
bias=True, regularizer="basis",
activation=None, num_bases=None,
self_loop=True, bias=True,
low_mem=False, activation=None,
dropout=0.0, self_loop=True,
layer_norm=False): low_mem=False,
dropout=0.0,
layer_norm=False,
):
super(RelGraphConv, self).__init__() super(RelGraphConv, self).__init__()
self.in_feat = in_feat self.in_feat = in_feat
self.out_feat = out_feat self.out_feat = out_feat
self.num_rels = num_rels self.num_rels = num_rels
self.regularizer = regularizer self.regularizer = regularizer
self.num_bases = num_bases self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0: if (
self.num_bases is None
or self.num_bases > self.num_rels
or self.num_bases < 0
):
self.num_bases = self.num_rels self.num_bases = self.num_rels
self.bias = bias self.bias = bias
self.activation = activation self.activation = activation
self.self_loop = self_loop self.self_loop = self_loop
assert low_mem is False, 'MXNet currently does not support low-memory implementation.' assert (
assert layer_norm is False, 'MXNet currently does not support layer norm.' low_mem is False
), "MXNet currently does not support low-memory implementation."
assert (
layer_norm is False
), "MXNet currently does not support layer norm."
if regularizer == "basis": if regularizer == "basis":
# add basis weights # add basis weights
self.weight = self.params.get( self.weight = self.params.get(
'weight', shape=(self.num_bases, self.in_feat, self.out_feat), "weight",
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) shape=(self.num_bases, self.in_feat, self.out_feat),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
# linear combination coefficients # linear combination coefficients
self.w_comp = self.params.get( self.w_comp = self.params.get(
'w_comp', shape=(self.num_rels, self.num_bases), "w_comp",
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) shape=(self.num_rels, self.num_bases),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
# message func # message func
self.message_func = self.basis_message_func self.message_func = self.basis_message_func
elif regularizer == "bdd": elif regularizer == "bdd":
if in_feat % num_bases != 0 or out_feat % num_bases != 0: if in_feat % num_bases != 0 or out_feat % num_bases != 0:
raise ValueError('Feature size must be a multiplier of num_bases.') raise ValueError(
"Feature size must be a multiplier of num_bases."
)
# add block diagonal weights # add block diagonal weights
self.submat_in = in_feat // self.num_bases self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases # assuming in_feat and out_feat are both divisible by num_bases
self.weight = self.params.get( self.weight = self.params.get(
'weight', "weight",
shape=(self.num_rels, self.num_bases * self.submat_in * self.submat_out), shape=(
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) self.num_rels,
self.num_bases * self.submat_in * self.submat_out,
),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
# message func # message func
self.message_func = self.bdd_message_func self.message_func = self.bdd_message_func
else: else:
...@@ -156,46 +178,57 @@ class RelGraphConv(gluon.Block): ...@@ -156,46 +178,57 @@ class RelGraphConv(gluon.Block):
# bias # bias
if self.bias: if self.bias:
self.h_bias = self.params.get('bias', shape=(out_feat,), self.h_bias = self.params.get(
init=mx.init.Zero()) "bias", shape=(out_feat,), init=mx.init.Zero()
)
# weight for self loop # weight for self loop
if self.self_loop: if self.self_loop:
self.loop_weight = self.params.get( self.loop_weight = self.params.get(
'W_0', shape=(in_feat, out_feat), "W_0",
init=mx.init.Xavier(magnitude=math.sqrt(2.0))) shape=(in_feat, out_feat),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)),
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def basis_message_func(self, edges): def basis_message_func(self, edges):
"""Message function for basis regularizer""" """Message function for basis regularizer"""
ctx = edges.src['h'].context ctx = edges.src["h"].context
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
# generate all weights from bases # generate all weights from bases
weight = self.weight.data(ctx).reshape( weight = self.weight.data(ctx).reshape(
self.num_bases, self.in_feat * self.out_feat) self.num_bases, self.in_feat * self.out_feat
)
weight = nd.dot(self.w_comp.data(ctx), weight).reshape( weight = nd.dot(self.w_comp.data(ctx), weight).reshape(
self.num_rels, self.in_feat, self.out_feat) self.num_rels, self.in_feat, self.out_feat
)
else: else:
weight = self.weight.data(ctx) weight = self.weight.data(ctx)
msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type']) msg = utils.bmm_maybe_select(edges.src["h"], weight, edges.data["type"])
if 'norm' in edges.data: if "norm" in edges.data:
msg = msg * edges.data['norm'] msg = msg * edges.data["norm"]
return {'msg': msg} return {"msg": msg}
def bdd_message_func(self, edges): def bdd_message_func(self, edges):
"""Message function for block-diagonal-decomposition regularizer""" """Message function for block-diagonal-decomposition regularizer"""
ctx = edges.src['h'].context ctx = edges.src["h"].context
if edges.src['h'].dtype in (np.int32, np.int64) and len(edges.src['h'].shape) == 1: if (
raise TypeError('Block decomposition does not allow integer ID feature.') edges.src["h"].dtype in (np.int32, np.int64)
weight = self.weight.data(ctx)[edges.data['type'], :].reshape( and len(edges.src["h"].shape) == 1
-1, self.submat_in, self.submat_out) ):
node = edges.src['h'].reshape(-1, 1, self.submat_in) raise TypeError(
"Block decomposition does not allow integer ID feature."
)
weight = self.weight.data(ctx)[edges.data["type"], :].reshape(
-1, self.submat_in, self.submat_out
)
node = edges.src["h"].reshape(-1, 1, self.submat_in)
msg = nd.batch_dot(node, weight).reshape(-1, self.out_feat) msg = nd.batch_dot(node, weight).reshape(-1, self.out_feat)
if 'norm' in edges.data: if "norm" in edges.data:
msg = msg * edges.data['norm'] msg = msg * edges.data["norm"]
return {'msg': msg} return {"msg": msg}
def forward(self, g, x, etypes, norm=None): def forward(self, g, x, etypes, norm=None):
""" """
...@@ -224,22 +257,25 @@ class RelGraphConv(gluon.Block): ...@@ -224,22 +257,25 @@ class RelGraphConv(gluon.Block):
mx.ndarray.NDArray mx.ndarray.NDArray
New node features. New node features.
""" """
assert g.is_homogeneous, \ assert g.is_homogeneous, (
"not a homogeneous graph; convert it with to_homogeneous " \ "not a homogeneous graph; convert it with to_homogeneous "
"and pass in the edge type as argument" "and pass in the edge type as argument"
)
with g.local_scope(): with g.local_scope():
g.ndata['h'] = x g.ndata["h"] = x
g.edata['type'] = etypes g.edata["type"] = etypes
if norm is not None: if norm is not None:
g.edata['norm'] = norm g.edata["norm"] = norm
if self.self_loop: if self.self_loop:
loop_message = utils.matmul_maybe_select(x, self.loop_weight.data(x.context)) loop_message = utils.matmul_maybe_select(
x, self.loop_weight.data(x.context)
)
# message passing # message passing
g.update_all(self.message_func, fn.sum(msg='msg', out='h')) g.update_all(self.message_func, fn.sum(msg="msg", out="h"))
# apply bias and activation # apply bias and activation
node_repr = g.ndata['h'] node_repr = g.ndata["h"]
if self.bias: if self.bias:
node_repr = node_repr + self.h_bias.data(x.context) node_repr = node_repr + self.h_bias.data(x.context)
if self.self_loop: if self.self_loop:
......
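Aside, not part of the commit: basis_message_func above builds per-relation weights from shared bases, W_r = sum_b a_{rb} V_b, then applies W_{type(e)} per edge. A NumPy sketch of the same computation; all names are hypothetical.

import numpy as np

rng = np.random.default_rng(0)
num_rels, num_bases, in_feat, out_feat = 4, 2, 5, 3
V = rng.normal(size=(num_bases, in_feat, out_feat))  # shared bases
a = rng.normal(size=(num_rels, num_bases))           # per-relation coefficients

# W_r = sum_b a[r, b] * V[b]; mirrors the reshape/dot in basis_message_func.
W = (a @ V.reshape(num_bases, -1)).reshape(num_rels, in_feat, out_feat)

h_src = rng.normal(size=(7, in_feat))            # one row per edge source
etype = rng.integers(0, num_rels, size=7)        # edge type id per edge
msg = np.einsum("ei,eio->eo", h_src, W[etype])   # per-edge message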
python/dgl/nn/mxnet/conv/sgconv.py

@@ -84,14 +84,17 @@ class SGConv(nn.Block):
         [ 2.2644043  -0.26684904]]
        <NDArray 6x2 @cpu(0)>
     """
-    def __init__(self,
-                 in_feats,
-                 out_feats,
-                 k=1,
-                 cached=False,
-                 bias=True,
-                 norm=None,
-                 allow_zero_in_degree=False):
+
+    def __init__(
+        self,
+        in_feats,
+        out_feats,
+        k=1,
+        cached=False,
+        bias=True,
+        norm=None,
+        allow_zero_in_degree=False,
+    ):
         super(SGConv, self).__init__()
         self._cached = cached
         self._cached_h = None
@@ -99,8 +102,12 @@ class SGConv(nn.Block):
         self._allow_zero_in_degree = allow_zero_in_degree
         with self.name_scope():
             self.norm = norm
-            self.fc = nn.Dense(out_feats, in_units=in_feats, use_bias=bias,
-                               weight_initializer=mx.init.Xavier())
+            self.fc = nn.Dense(
+                out_feats,
+                in_units=in_feats,
+                use_bias=bias,
+                weight_initializer=mx.init.Xavier(),
+            )

     def set_allow_zero_in_degree(self, set_value):
         r"""
@@ -152,30 +159,33 @@ class SGConv(nn.Block):
         with graph.local_scope():
             if not self._allow_zero_in_degree:
                 if graph.in_degrees().min() == 0:
-                    raise DGLError('There are 0-in-degree nodes in the graph, '
-                                   'output for those nodes will be invalid. '
-                                   'This is harmful for some applications, '
-                                   'causing silent performance regression. '
-                                   'Adding self-loop on the input graph by '
-                                   'calling `g = dgl.add_self_loop(g)` will resolve '
-                                   'the issue. Setting ``allow_zero_in_degree`` '
-                                   'to be `True` when constructing this module will '
-                                   'suppress the check and let the code run.')
+                    raise DGLError(
+                        "There are 0-in-degree nodes in the graph, "
+                        "output for those nodes will be invalid. "
+                        "This is harmful for some applications, "
+                        "causing silent performance regression. "
+                        "Adding self-loop on the input graph by "
+                        "calling `g = dgl.add_self_loop(g)` will resolve "
+                        "the issue. Setting ``allow_zero_in_degree`` "
+                        "to be `True` when constructing this module will "
+                        "suppress the check and let the code run."
+                    )

             if self._cached_h is not None:
                 feat = self._cached_h
             else:
                 # compute normalization
-                degs = nd.clip(graph.in_degrees().astype(feat.dtype), 1, float('inf'))
+                degs = nd.clip(
+                    graph.in_degrees().astype(feat.dtype), 1, float("inf")
+                )
                 norm = nd.power(degs, -0.5).expand_dims(1)
                 norm = norm.as_in_context(feat.context)
                 # compute (D^-1 A D)^k X
                 for _ in range(self._k):
                     feat = feat * norm
-                    graph.ndata['h'] = feat
-                    graph.update_all(fn.copy_u('h', 'm'),
-                                     fn.sum('m', 'h'))
-                    feat = graph.ndata.pop('h')
+                    graph.ndata["h"] = feat
+                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
+                    feat = graph.ndata.pop("h")
                     feat = feat * norm

                 if self.norm is not None:
...
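Aside, not part of the commit: the non-cached branch above is the SGC precomputation (D^{-1/2} A D^{-1/2})^k X, which involves no learned parameters and can therefore be cached (cf. _cached_h). A dense NumPy sketch; names are hypothetical.

import numpy as np

def sgc_precompute(A, X, k):
    A = np.asarray(A, dtype=float)
    deg = np.clip(A.sum(axis=1), 1, None)
    d_inv_sqrt = deg ** -0.5
    S = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    for _ in range(k):
        X = S @ X
    return X  # a single linear layer (self.fc above) is then applied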
python/dgl/nn/mxnet/glob.py

@@ -3,11 +3,24 @@
 from mxnet import gluon, nd
 from mxnet.gluon import nn

-from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
-    softmax_nodes, topk_nodes
+from ...readout import (
+    broadcast_nodes,
+    max_nodes,
+    mean_nodes,
+    softmax_nodes,
+    sum_nodes,
+    topk_nodes,
+)

-__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
-           'GlobalAttentionPooling', 'Set2Set']
+__all__ = [
+    "SumPooling",
+    "AvgPooling",
+    "MaxPooling",
+    "SortPooling",
+    "GlobalAttentionPooling",
+    "Set2Set",
+]

+
 class SumPooling(nn.Block):
     r"""Apply sum pooling over the nodes in the graph.
@@ -15,6 +28,7 @@ class SumPooling(nn.Block):
     .. math::
         r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
     """
+
     def __init__(self):
         super(SumPooling, self).__init__()
@@ -36,13 +50,13 @@ class SumPooling(nn.Block):
             :math:`B` refers to the batch size.
         """
         with graph.local_scope():
-            graph.ndata['h'] = feat
-            readout = sum_nodes(graph, 'h')
-            graph.ndata.pop('h')
+            graph.ndata["h"] = feat
+            readout = sum_nodes(graph, "h")
+            graph.ndata.pop("h")
             return readout

     def __repr__(self):
-        return 'SumPooling()'
+        return "SumPooling()"


 class AvgPooling(nn.Block):
@@ -51,6 +65,7 @@ class AvgPooling(nn.Block):
     .. math::
         r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
     """
+
     def __init__(self):
         super(AvgPooling, self).__init__()
@@ -72,13 +87,13 @@ class AvgPooling(nn.Block):
             :math:`B` refers to the batch size.
         """
         with graph.local_scope():
-            graph.ndata['h'] = feat
-            readout = mean_nodes(graph, 'h')
-            graph.ndata.pop('h')
+            graph.ndata["h"] = feat
+            readout = mean_nodes(graph, "h")
+            graph.ndata.pop("h")
             return readout

     def __repr__(self):
-        return 'AvgPooling()'
+        return "AvgPooling()"


 class MaxPooling(nn.Block):
@@ -87,6 +102,7 @@ class MaxPooling(nn.Block):
     .. math::
         r^{(i)} = \max_{k=1}^{N_i} \left( x^{(i)}_k \right)
     """
+
     def __init__(self):
         super(MaxPooling, self).__init__()
@@ -108,13 +124,13 @@ class MaxPooling(nn.Block):
             :math:`B` refers to the batch size.
         """
         with graph.local_scope():
-            graph.ndata['h'] = feat
-            readout = max_nodes(graph, 'h')
-            graph.ndata.pop('h')
+            graph.ndata["h"] = feat
+            readout = max_nodes(graph, "h")
+            graph.ndata.pop("h")
             return readout

     def __repr__(self):
-        return 'MaxPooling()'
+        return "MaxPooling()"


 class SortPooling(nn.Block):
@@ -126,6 +142,7 @@ class SortPooling(nn.Block):
     k : int
         The number of nodes to hold for each graph.
     """
+
     def __init__(self, k):
         super(SortPooling, self).__init__()
         self.k = k
@@ -150,14 +167,15 @@ class SortPooling(nn.Block):
         # Sort the feature of each node in ascending order.
         with graph.local_scope():
             feat = feat.sort(axis=-1)
-            graph.ndata['h'] = feat
+            graph.ndata["h"] = feat
             # Sort nodes according to their last features.
-            ret = topk_nodes(graph, 'h', self.k, sortby=-1)[0].reshape(
-                -1, self.k * feat.shape[-1])
+            ret = topk_nodes(graph, "h", self.k, sortby=-1)[0].reshape(
+                -1, self.k * feat.shape[-1]
+            )
             return ret

     def __repr__(self):
-        return 'SortPooling(k={})'.format(self.k)
+        return "SortPooling(k={})".format(self.k)


 class GlobalAttentionPooling(nn.Block):
@@ -176,6 +194,7 @@ class GlobalAttentionPooling(nn.Block):
         A neural network applied to each feature before combining them
         with attention scores.
     """
+
     def __init__(self, gate_nn, feat_nn=None):
         super(GlobalAttentionPooling, self).__init__()
         with self.name_scope():
@@ -201,14 +220,16 @@ class GlobalAttentionPooling(nn.Block):
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
-            assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
+            assert (
+                gate.shape[-1] == 1
+            ), "The output of gate_nn should have size 1 at the last axis."
             feat = self.feat_nn(feat) if self.feat_nn else feat

-            graph.ndata['gate'] = gate
-            gate = softmax_nodes(graph, 'gate')
-            graph.ndata['r'] = feat * gate
-            readout = sum_nodes(graph, 'r')
+            graph.ndata["gate"] = gate
+            gate = softmax_nodes(graph, "gate")
+            graph.ndata["r"] = feat * gate
+            readout = sum_nodes(graph, "r")

             return readout
@@ -239,6 +260,7 @@ class Set2Set(nn.Block):
     n_layers : int
         Number of recurrent layers.
     """
+
     def __init__(self, input_dim, n_iters, n_layers):
         super(Set2Set, self).__init__()
         self.input_dim = input_dim
@@ -247,7 +269,8 @@ class Set2Set(nn.Block):
         self.n_layers = n_layers
         with self.name_scope():
             self.lstm = gluon.rnn.LSTM(
-                self.input_dim, num_layers=n_layers, input_size=self.output_dim)
+                self.input_dim, num_layers=n_layers, input_size=self.output_dim
+            )

     def forward(self, graph, feat):
         r"""Compute set2set pooling.
@@ -269,28 +292,36 @@ class Set2Set(nn.Block):
         with graph.local_scope():
             batch_size = graph.batch_size

-            h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
-                 nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
+            h = (
+                nd.zeros(
+                    (self.n_layers, batch_size, self.input_dim),
+                    ctx=feat.context,
+                ),
+                nd.zeros(
+                    (self.n_layers, batch_size, self.input_dim),
+                    ctx=feat.context,
+                ),
+            )

             q_star = nd.zeros((batch_size, self.output_dim), ctx=feat.context)

             for _ in range(self.n_iters):
                 q, h = self.lstm(q_star.expand_dims(axis=0), h)
                 q = q.reshape((batch_size, self.input_dim))
-                e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
-                graph.ndata['e'] = e
-                alpha = softmax_nodes(graph, 'e')
-                graph.ndata['r'] = feat * alpha
-                readout = sum_nodes(graph, 'r')
+                e = (feat * broadcast_nodes(graph, q)).sum(
+                    axis=-1, keepdims=True
+                )
+                graph.ndata["e"] = e
+                alpha = softmax_nodes(graph, "e")
+                graph.ndata["r"] = feat * alpha
+                readout = sum_nodes(graph, "r")
                 q_star = nd.concat(q, readout, dim=-1)

             return q_star

     def __repr__(self):
-        summary = 'Set2Set('
-        summary += 'in={}, out={}, ' \
-                   'n_iters={}, n_layers={}'.format(self.input_dim,
-                                                    self.output_dim,
-                                                    self.n_iters,
-                                                    self.n_layers)
-        summary += ')'
+        summary = "Set2Set("
+        summary += "in={}, out={}, " "n_iters={}, n_layers={}".format(
+            self.input_dim, self.output_dim, self.n_iters, self.n_layers
+        )
+        summary += ")"
         return summary
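Aside, not part of the commit: a usage sketch for the readout blocks above on a batched graph, assuming an MXNet build of DGL; shapes are illustrative.

import dgl
import mxnet as mx
from dgl.nn.mxnet.glob import SumPooling

g1 = dgl.graph(([0, 1], [1, 2]))            # 3 nodes
g2 = dgl.graph(([0, 1, 2], [1, 2, 3]))      # 4 nodes
bg = dgl.batch([g1, g2])
feat = mx.nd.random.uniform(shape=(7, 5))   # 3 + 4 nodes in the batch
readout = SumPooling()(bg, feat)            # shape (2, 5): one row per graph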
python/dgl/nn/mxnet/hetero.py

@@ -2,7 +2,8 @@
 from mxnet import nd
 from mxnet.gluon import nn

-__all__ = ['HeteroGraphConv']
+__all__ = ["HeteroGraphConv"]

+
 class HeteroGraphConv(nn.Block):
     r"""A generic module for computing convolution on heterogeneous graphs
@@ -118,7 +119,8 @@ class HeteroGraphConv(nn.Block):
     mods : dict[str, nn.Module]
         Modules associated with every edge types.
     """
-    def __init__(self, mods, aggregate='sum'):
+
+    def __init__(self, mods, aggregate="sum"):
         super(HeteroGraphConv, self).__init__()
         with self.name_scope():
             for name, mod in mods.items():
@@ -127,7 +129,9 @@ class HeteroGraphConv(nn.Block):
         # Do not break if graph has 0-in-degree nodes.
         # Because there is no general rule to add self-loop for heterograph.
         for _, v in self.mods.items():
-            set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
+            set_allow_zero_in_degree_fn = getattr(
+                v, "set_allow_zero_in_degree", None
+            )
             if callable(set_allow_zero_in_degree_fn):
                 set_allow_zero_in_degree_fn(True)
         if isinstance(aggregate, str):
@@ -160,7 +164,7 @@ class HeteroGraphConv(nn.Block):
             mod_args = {}
         if mod_kwargs is None:
             mod_kwargs = {}
-        outputs = {nty : [] for nty in g.dsttypes}
+        outputs = {nty: [] for nty in g.dsttypes}
         if isinstance(inputs, tuple):
             src_inputs, dst_inputs = inputs
             for stype, etype, dtype in g.canonical_etypes:
@@ -171,7 +175,8 @@ class HeteroGraphConv(nn.Block):
                     rel_graph,
                     (src_inputs[stype], dst_inputs[dtype]),
                     *mod_args.get(etype, ()),
-                    **mod_kwargs.get(etype, {}))
+                    **mod_kwargs.get(etype, {})
+                )
                 outputs[dtype].append(dstdata)
         else:
             for stype, etype, dtype in g.canonical_etypes:
@@ -182,7 +187,8 @@ class HeteroGraphConv(nn.Block):
                     rel_graph,
                     (inputs[stype], inputs[dtype]),
                     *mod_args.get(etype, ()),
-                    **mod_kwargs.get(etype, {})
+                    **mod_kwargs.get(etype, {})
+                )
                 outputs[dtype].append(dstdata)
         rsts = {}
         for nty, alist in outputs.items():
@@ -191,12 +197,13 @@ class HeteroGraphConv(nn.Block):
         return rsts

     def __repr__(self):
-        summary = 'HeteroGraphConv({\n'
+        summary = "HeteroGraphConv({\n"
         for name, mod in self.mods.items():
-            summary += '  {} : {},\n'.format(name, mod)
-        summary += '\n})'
+            summary += "  {} : {},\n".format(name, mod)
+        summary += "\n})"
         return summary


 def get_aggregate_fn(agg):
     """Internal function to get the aggregation function for node data
     generated from different relations.
@@ -213,29 +220,35 @@ def get_aggregate_fn(agg):
         Aggregator function that takes a list of tensors to aggregate
        and returns one aggregated tensor.
     """
-    if agg == 'sum':
+    if agg == "sum":
         fn = nd.sum
-    elif agg == 'max':
+    elif agg == "max":
         fn = nd.max
-    elif agg == 'min':
+    elif agg == "min":
         fn = nd.min
-    elif agg == 'mean':
+    elif agg == "mean":
         fn = nd.mean
-    elif agg == 'stack':
+    elif agg == "stack":
         fn = None  # will not be called
     else:
-        raise DGLError('Invalid cross type aggregator. Must be one of '
-                       '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
-    if agg == 'stack':
+        raise DGLError(
+            "Invalid cross type aggregator. Must be one of "
+            '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
+        )
+    if agg == "stack":
+
         def stack_agg(inputs, dsttype):  # pylint: disable=unused-argument
             if len(inputs) == 0:
                 return None
             return nd.stack(*inputs, axis=1)

         return stack_agg
     else:
+
        def aggfn(inputs, dsttype):  # pylint: disable=unused-argument
             if len(inputs) == 0:
                 return None
             stacked = nd.stack(*inputs, axis=0)
             return fn(stacked, axis=0)

         return aggfn
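Aside, not part of the commit: a usage sketch of HeteroGraphConv, assuming an MXNet build of DGL; node/edge types and sizes are illustrative. Note the constructor above already calls set_allow_zero_in_degree(True) on each submodule, so no self-loops are needed here.

import dgl
import mxnet as mx
from dgl.nn.mxnet import HeteroGraphConv
from dgl.nn.mxnet.conv import GraphConv

g = dgl.heterograph({
    ("user", "follows", "user"): ([0, 1], [1, 2]),
    ("user", "plays", "game"): ([0, 2], [0, 1]),
})
conv = HeteroGraphConv(
    {"follows": GraphConv(5, 4), "plays": GraphConv(5, 4)}, aggregate="sum"
)
conv.initialize()
inputs = {
    "user": mx.nd.random.uniform(shape=(3, 5)),
    "game": mx.nd.random.uniform(shape=(2, 5)),
}
outputs = conv(g, inputs)  # dict keyed by destination node type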
"""Utilities for pytorch NN package""" """Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name # pylint: disable=no-member, invalid-name
from mxnet import nd, gluon
import numpy as np import numpy as np
from mxnet import gluon, nd
from ... import DGLGraph from ... import DGLGraph
def matmul_maybe_select(A, B): def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector. """Perform Matrix multiplication C = A * B but A could be an integer id vector.
...@@ -46,6 +48,7 @@ def matmul_maybe_select(A, B): ...@@ -46,6 +48,7 @@ def matmul_maybe_select(A, B):
else: else:
return nd.dot(A, B) return nd.dot(A, B)
def bmm_maybe_select(A, B, index): def bmm_maybe_select(A, B, index):
"""Slice submatrices of A by the given index and perform bmm. """Slice submatrices of A by the given index and perform bmm.
...@@ -86,6 +89,7 @@ def bmm_maybe_select(A, B, index): ...@@ -86,6 +89,7 @@ def bmm_maybe_select(A, B, index):
BB = nd.take(B, index, axis=0) BB = nd.take(B, index, axis=0)
return nd.batch_dot(A.expand_dims(1), BB).squeeze(1) return nd.batch_dot(A.expand_dims(1), BB).squeeze(1)
def normalize(x, p=2, axis=1, eps=1e-12): def normalize(x, p=2, axis=1, eps=1e-12):
r"""Performs :math:`L_p` normalization of inputs over specified dimension. r"""Performs :math:`L_p` normalization of inputs over specified dimension.
...@@ -104,9 +108,12 @@ def normalize(x, p=2, axis=1, eps=1e-12): ...@@ -104,9 +108,12 @@ def normalize(x, p=2, axis=1, eps=1e-12):
dim (int): the dimension to reduce. Default: 1 dim (int): the dimension to reduce. Default: 1
eps (float): small value to avoid division by zero. Default: 1e-12 eps (float): small value to avoid division by zero. Default: 1e-12
""" """
denom = nd.clip(nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float('inf')) denom = nd.clip(
nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float("inf")
)
return x / denom return x / denom
class Sequential(gluon.nn.Sequential): class Sequential(gluon.nn.Sequential):
r"""A squential container for stacking graph neural network blocks r"""A squential container for stacking graph neural network blocks
...@@ -197,6 +204,7 @@ class Sequential(gluon.nn.Sequential): ...@@ -197,6 +204,7 @@ class Sequential(gluon.nn.Sequential):
[-119.23065 -26.78553 -111.11185 -166.08322 ]] [-119.23065 -26.78553 -111.11185 -166.08322 ]]
<NDArray 4x4 @cpu(0)> <NDArray 4x4 @cpu(0)>
""" """
def __init__(self, prefix=None, params=None): def __init__(self, prefix=None, params=None):
super(Sequential, self).__init__(prefix=prefix, params=params) super(Sequential, self).__init__(prefix=prefix, params=params)
...@@ -224,6 +232,8 @@ class Sequential(gluon.nn.Sequential): ...@@ -224,6 +232,8 @@ class Sequential(gluon.nn.Sequential):
feats = (feats,) feats = (feats,)
feats = module(graph, *feats) feats = module(graph, *feats)
else: else:
raise TypeError('The first argument of forward must be a DGLGraph' raise TypeError(
' or a list of DGLGraph s') "The first argument of forward must be a DGLGraph"
" or a list of DGLGraph s"
)
return feats return feats
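Aside, not part of the commit: normalize() above computes x / max(||x||_p, eps) along an axis. An equivalent NumPy sketch; names are hypothetical.

import numpy as np

def normalize_np(x, p=2, axis=1, eps=1e-12):
    # Same idea as normalize() above: divide by the clamped L_p norm.
    denom = np.linalg.norm(x, ord=p, axis=axis, keepdims=True)
    return x / np.clip(denom, eps, None)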
python/dgl/nn/mxnet/conv/__init__.py

@@ -3,38 +3,66 @@
 from .agnnconv import AGNNConv
 from .appnpconv import APPNPConv
+from .atomicconv import AtomicConv
+from .cfconv import CFConv
 from .chebconv import ChebConv
+from .densechebconv import DenseChebConv
+from .densegraphconv import DenseGraphConv
+from .densesageconv import DenseSAGEConv
+from .dgnconv import DGNConv
+from .dotgatconv import DotGatConv
 from .edgeconv import EdgeConv
+from .egatconv import EGATConv
+from .egnnconv import EGNNConv
 from .gatconv import GATConv
+from .gatedgraphconv import GatedGraphConv
 from .gatv2conv import GATv2Conv
-from .egatconv import EGATConv
+from .gcn2conv import GCN2Conv
 from .ginconv import GINConv
 from .gineconv import GINEConv
 from .gmmconv import GMMConv
-from .graphconv import GraphConv, EdgeWeightNorm
+from .graphconv import EdgeWeightNorm, GraphConv
+from .grouprevres import GroupRevRes
+from .hgtconv import HGTConv
 from .nnconv import NNConv
+from .pnaconv import PNAConv
 from .relgraphconv import RelGraphConv
 from .sageconv import SAGEConv
 from .sgconv import SGConv
 from .tagconv import TAGConv
-from .gatedgraphconv import GatedGraphConv
-from .densechebconv import DenseChebConv
-from .densegraphconv import DenseGraphConv
-from .densesageconv import DenseSAGEConv
-from .atomicconv import AtomicConv
-from .cfconv import CFConv
-from .dotgatconv import DotGatConv
 from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
-from .gcn2conv import GCN2Conv
-from .hgtconv import HGTConv
-from .grouprevres import GroupRevRes
-from .egnnconv import EGNNConv
-from .pnaconv import PNAConv
-from .dgnconv import DGNConv

-__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
-           'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GINEConv',
-           'GatedGraphConv', 'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv',
-           'DenseSAGEConv', 'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv',
-           'TWIRLSConv', 'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes',
-           'EGNNConv', 'PNAConv', 'DGNConv']
+__all__ = [
+    "GraphConv",
+    "EdgeWeightNorm",
+    "GATConv",
+    "GATv2Conv",
+    "EGATConv",
+    "TAGConv",
+    "RelGraphConv",
+    "SAGEConv",
+    "SGConv",
+    "APPNPConv",
+    "GINConv",
+    "GINEConv",
+    "GatedGraphConv",
+    "GMMConv",
+    "ChebConv",
+    "AGNNConv",
+    "NNConv",
+    "DenseGraphConv",
+    "DenseSAGEConv",
+    "DenseChebConv",
+    "EdgeConv",
+    "AtomicConv",
+    "CFConv",
+    "DotGatConv",
+    "TWIRLSConv",
+    "TWIRLSUnfoldingAndAttention",
+    "GCN2Conv",
+    "HGTConv",
+    "GroupRevRes",
+    "EGNNConv",
+    "PNAConv",
+    "DGNConv",
+]
...@@ -5,9 +5,9 @@ from torch import nn ...@@ -5,9 +5,9 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .... import function as fn from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError from ....base import DGLError
from ....utils import expand_as_pair from ....utils import expand_as_pair
from ...functional import edge_softmax
class AGNNConv(nn.Module): class AGNNConv(nn.Module):
...@@ -74,16 +74,16 @@ class AGNNConv(nn.Module): ...@@ -74,16 +74,16 @@ class AGNNConv(nn.Module):
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
grad_fn=<BinaryReduceBackward>) grad_fn=<BinaryReduceBackward>)
""" """

    def __init__(
        self, init_beta=1.0, learn_beta=True, allow_zero_in_degree=False
    ):
        super(AGNNConv, self).__init__()
        self._allow_zero_in_degree = allow_zero_in_degree
        if learn_beta:
            self.beta = nn.Parameter(th.Tensor([init_beta]))
        else:
            self.register_buffer("beta", th.Tensor([init_beta]))

    def set_allow_zero_in_degree(self, set_value):
        r"""
@@ -133,26 +133,28 @@ class AGNNConv(nn.Module):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            feat_src, feat_dst = expand_as_pair(feat, graph)

            graph.srcdata["h"] = feat_src
            graph.srcdata["norm_h"] = F.normalize(feat_src, p=2, dim=-1)
            if isinstance(feat, tuple) or graph.is_block:
                graph.dstdata["norm_h"] = F.normalize(feat_dst, p=2, dim=-1)
            # compute cosine distance
            graph.apply_edges(fn.u_dot_v("norm_h", "norm_h", "cos"))
            cos = graph.edata.pop("cos")
            e = self.beta * cos
            graph.edata["p"] = edge_softmax(graph, e)
            graph.update_all(fn.u_mul_e("h", "p", "m"), fn.sum("m", "h"))
            return graph.dstdata.pop("h")
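Taken together, the forward pass above normalizes the features, scores each edge by beta times the cosine similarity of its endpoints, softmax-normalizes the scores per destination node, and averages neighbor features with those weights. A minimal usage sketch (graph and values are illustrative, not part of the diff):

import dgl
import torch as th
from dgl.nn import AGNNConv

# Self-loops avoid the 0-in-degree error raised in forward() above.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])))
feat = th.ones(6, 10)
conv = AGNNConv()
out = conv(g, feat)  # (6, 10): attention-weighted average of neighbor features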
@@ -56,10 +56,7 @@ class APPNPConv(nn.Module):
            0.5000]])
    """

    def __init__(self, k, alpha, edge_drop=0.0):
        super(APPNPConv, self).__init__()
        self._k = k
        self._alpha = alpha
@@ -94,28 +91,29 @@ class APPNPConv(nn.Module):
        with graph.local_scope():
            if edge_weight is None:
                src_norm = th.pow(
                    graph.out_degrees().float().clamp(min=1), -0.5
                )
                shp = src_norm.shape + (1,) * (feat.dim() - 1)
                src_norm = th.reshape(src_norm, shp).to(feat.device)
                dst_norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
                shp = dst_norm.shape + (1,) * (feat.dim() - 1)
                dst_norm = th.reshape(dst_norm, shp).to(feat.device)
            else:
                edge_weight = EdgeWeightNorm("both")(graph, edge_weight)
            feat_0 = feat
            for _ in range(self._k):
                # normalization by src node
                if edge_weight is None:
                    feat = feat * src_norm
                graph.ndata["h"] = feat
                w = (
                    th.ones(graph.number_of_edges(), 1)
                    if edge_weight is None
                    else edge_weight
                )
                graph.edata["w"] = self.edge_drop(w).to(feat.device)
                graph.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"))
                feat = graph.ndata.pop("h")
                # normalization by dst node
                if edge_weight is None:
                    feat = feat * dst_norm
...
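Each of the k rounds above multiplies the features by the symmetrically normalized adjacency (via the src/dst degree scaling), with optional dropout on the propagation weights; the alpha-mixing with feat_0 sits in the lines collapsed below the hunk. A hedged usage sketch with illustrative values:

import dgl
import torch as th
from dgl.nn import APPNPConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feat = th.randn(3, 4)
conv = APPNPConv(k=3, alpha=0.1)
out = conv(g, feat)  # same shape as the input features: (3, 4)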
@@ -4,6 +4,7 @@ import numpy as np
import torch as th
import torch.nn as nn


class RadialPooling(nn.Module):
    r"""Radial pooling from `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__
@@ -42,15 +43,21 @@ class RadialPooling(nn.Module):
    rbf_kernel_scaling : float32 tensor of shape (K)
        :math:`\gamma_k` in the equations above. K for the number of radial filters.
    """

    def __init__(
        self, interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling
    ):
        super(RadialPooling, self).__init__()

        self.interaction_cutoffs = nn.Parameter(
            interaction_cutoffs.reshape(-1, 1, 1), requires_grad=True
        )
        self.rbf_kernel_means = nn.Parameter(
            rbf_kernel_means.reshape(-1, 1, 1), requires_grad=True
        )
        self.rbf_kernel_scaling = nn.Parameter(
            rbf_kernel_scaling.reshape(-1, 1, 1), requires_grad=True
        )

    def forward(self, distances):
        """
@@ -69,14 +76,19 @@ class RadialPooling(nn.Module):
        Float32 tensor of shape (K, E, 1)
            Transformed edge distances. K for the number of radial filters.
        """
        scaled_euclidean_distance = (
            -self.rbf_kernel_scaling * (distances - self.rbf_kernel_means) ** 2
        )  # (K, E, 1)
        rbf_kernel_results = th.exp(scaled_euclidean_distance)  # (K, E, 1)

        cos_values = 0.5 * (
            th.cos(np.pi * distances / self.interaction_cutoffs) + 1
        )  # (K, E, 1)
        cutoff_values = th.where(
            distances <= self.interaction_cutoffs,
            cos_values,
            th.zeros_like(cos_values),
        )  # (K, E, 1)

        # Note that there appears to be an inconsistency between the paper and
        # DeepChem's implementation. In the paper, the scaled_euclidean_distance first
@@ -84,6 +96,7 @@ class RadialPooling(nn.Module):
        # the practice of DeepChem.
        return rbf_kernel_results * cutoff_values
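The returned filter is exp(-gamma_k * (d - mu_k)^2) gated by the cosine cutoff 0.5 * (cos(pi * d / rc_k) + 1), which decays smoothly to zero at the cutoff radius. A standalone numeric sketch of one filter (hypothetical values, plain tensors instead of nn.Parameter):

import numpy as np
import torch as th

d = th.tensor([[1.0], [2.5], [4.0]])  # (E, 1) edge distances
rc, mu, gamma = 3.0, 1.5, 2.0         # one radial filter, K = 1
rbf = th.exp(-gamma * (d - mu) ** 2)
cutoff = th.where(d <= rc, 0.5 * (th.cos(np.pi * d / rc) + 1), th.zeros_like(d))
print(rbf * cutoff)  # the 4.0 distance lies beyond rc and contributes exactly 0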

def msg_func(edges):
    """
@@ -103,8 +116,12 @@ def msg_func(edges):
        radial filters and T for the number of features to use
        (types of atomic number in the paper).
    """
    return {
        "m": th.einsum("ij,ik->ijk", edges.src["hv"], edges.data["he"]).view(
            len(edges), -1
        )
    }
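The einsum above forms, for every edge, the outer product of the source-node features (T,) with the radial filter outputs (K,), then flattens it, so each message carries T * K entries. A toy shape check (hypothetical data):

import torch as th

hv = th.randn(5, 3)  # (E, T) source features gathered per edge
he = th.randn(5, 2)  # (E, K) radial filter outputs per edge
m = th.einsum("ij,ik->ijk", hv, he).view(5, -1)  # per-edge outer product
print(m.shape)  # torch.Size([5, 6])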

def reduce_func(nodes):
    """
@@ -125,7 +142,8 @@ def reduce_func(nodes):
        radial filters and T for the number of features to use
        (types of atomic number in the paper).
    """
    return {"hv_new": nodes.mailbox["m"].sum(1)}


class AtomicConv(nn.Module):
    r"""Atomic Convolution Layer from `Atomic Convolutional Networks for
@@ -219,19 +237,29 @@ class AtomicConv(nn.Module):
            [0.5000, 0.5000, 0.5000],
            [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
    """

    def __init__(
        self,
        interaction_cutoffs,
        rbf_kernel_means,
        rbf_kernel_scaling,
        features_to_use=None,
    ):
        super(AtomicConv, self).__init__()

        self.radial_pooling = RadialPooling(
            interaction_cutoffs=interaction_cutoffs,
            rbf_kernel_means=rbf_kernel_means,
            rbf_kernel_scaling=rbf_kernel_scaling,
        )
        if features_to_use is None:
            self.num_channels = 1
            self.features_to_use = None
        else:
            self.num_channels = len(features_to_use)
            self.features_to_use = nn.Parameter(
                features_to_use, requires_grad=False
            )

    def forward(self, graph, feat, distances):
        """
@@ -257,11 +285,15 @@ class AtomicConv(nn.Module):
            number of radial filters, and :math:`T` for the number of types of atomic numbers.
        """
        with graph.local_scope():
            radial_pooled_values = self.radial_pooling(distances)  # (K, E, 1)
            if self.features_to_use is not None:
                feat = (feat == self.features_to_use).float()  # (V, T)
            graph.ndata["hv"] = feat
            graph.edata["he"] = radial_pooled_values.transpose(1, 0).squeeze(
                -1
            )  # (E, K)
            graph.update_all(msg_func, reduce_func)

            return graph.ndata["hv_new"].view(
                graph.number_of_nodes(), -1
            )  # (V, K * T)
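End to end: distances pass through the radial filters, atomic numbers are one-hot encoded against features_to_use, and each node sums the per-edge outer products. A hedged usage sketch (all values illustrative):

import dgl
import torch as th
from dgl.nn import AtomicConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
feat = th.tensor([[6.0], [1.0], [8.0]])  # atomic numbers, shape (V, 1)
dist = th.tensor([[1.0], [2.0], [1.5]])  # edge distances, shape (E, 1)
conv = AtomicConv(
    interaction_cutoffs=th.tensor([3.0]),
    rbf_kernel_means=th.tensor([1.0]),
    rbf_kernel_scaling=th.tensor([2.0]),
    features_to_use=th.tensor([1.0, 6.0, 8.0]),
)
out = conv(g, feat, dist)  # (V, K * T) = (3, 3)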
@@ -5,6 +5,7 @@ import torch.nn as nn
from .... import function as fn


class ShiftedSoftplus(nn.Module):
    r"""Applies the element-wise function:
@@ -18,6 +19,7 @@ class ShiftedSoftplus(nn.Module):
    shift : int
        :math:`\text{shift}` value for the mathematical formulation. Default to 2.
    """

    def __init__(self, beta=1, shift=2, threshold=20):
        super(ShiftedSoftplus, self).__init__()
@@ -43,6 +45,7 @@ class ShiftedSoftplus(nn.Module):
        """
        return self.softplus(inputs) - np.log(float(self.shift))
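With the default shift of 2 the activation passes through the origin, since softplus(0) = ln 2 exactly cancels the subtracted log. A quick check with a stand-in definition (the class itself is module-internal):

import numpy as np
import torch as th
import torch.nn.functional as F

ssp = lambda x: F.softplus(x) - np.log(2.0)  # ShiftedSoftplus with beta=1, shift=2
print(ssp(th.tensor([0.0])))  # tensor([0.])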

class CFConv(nn.Module):
    r"""CFConv from `SchNet: A continuous-filter convolutional neural network for
    modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__
@@ -87,6 +90,7 @@ class CFConv(nn.Module):
            [-0.1209, -0.2289],
            [-0.1283, -0.2240]], grad_fn=<SubBackward0>)
    """

    def __init__(self, node_in_feats, edge_in_feats, hidden_feats, out_feats):
        super(CFConv, self).__init__()
@@ -94,12 +98,11 @@ class CFConv(nn.Module):
            nn.Linear(edge_in_feats, hidden_feats),
            ShiftedSoftplus(),
            nn.Linear(hidden_feats, hidden_feats),
            ShiftedSoftplus(),
        )
        self.project_node = nn.Linear(node_in_feats, hidden_feats)
        self.project_out = nn.Sequential(
            nn.Linear(hidden_feats, out_feats), ShiftedSoftplus()
        )

    def forward(self, g, node_feats, edge_feats):
@@ -136,7 +139,7 @@ class CFConv(nn.Module):
                node_feats_src, _ = node_feats
            else:
                node_feats_src = node_feats
            g.srcdata["hv"] = self.project_node(node_feats_src)
            g.edata["he"] = self.project_edge(edge_feats)
            g.update_all(fn.u_mul_e("hv", "he", "m"), fn.sum("m", "h"))
            return self.project_out(g.dstdata["h"])
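The forward pass projects nodes and edges into a shared hidden space, multiplies them per edge (the continuous filter), sums into destination nodes, and projects out. A hedged usage sketch (sizes illustrative):

import dgl
import torch as th
from dgl.nn import CFConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
node_feats = th.randn(3, 4)  # (V, node_in_feats)
edge_feats = th.randn(3, 5)  # (E, edge_in_feats)
conv = CFConv(node_in_feats=4, edge_in_feats=5, hidden_feats=8, out_feats=2)
out = conv(g, node_feats, edge_feats)  # (V, out_feats) = (3, 2)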
"""Torch Module for Chebyshev Spectral Graph Convolution layer""" """Torch Module for Chebyshev Spectral Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import torch as th import torch as th
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from .... import broadcast_nodes
from .... import function as fn
from ....base import dgl_warning from ....base import dgl_warning
from .... import broadcast_nodes, function as fn
class ChebConv(nn.Module): class ChebConv(nn.Module):
@@ -60,12 +61,7 @@ class ChebConv(nn.Module):
            [-0.2370,  3.0164]], grad_fn=<AddBackward0>)
    """

    def __init__(self, in_feats, out_feats, k, activation=F.relu, bias=True):
        super(ChebConv, self).__init__()
        self._k = k
        self._in_feats = in_feats
@@ -97,20 +93,25 @@ class ChebConv(nn.Module):
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """

        def unnLaplacian(feat, D_invsqrt, graph):
            """Operation Feat * D^-1/2 A D^-1/2"""
            graph.ndata["h"] = feat * D_invsqrt
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            return graph.ndata.pop("h") * D_invsqrt

        with graph.local_scope():
            D_invsqrt = (
                th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
                .unsqueeze(-1)
                .to(feat.device)
            )

            if lambda_max is None:
                dgl_warning(
                    "lambda_max is not provided, using default value of 2. "
                    "Please use dgl.laplacian_lambda_max to compute the eigenvalues."
                )
                lambda_max = [2] * graph.batch_size

            if isinstance(lambda_max, list):
@@ -120,7 +121,7 @@ class ChebConv(nn.Module):
                # broadcast from (B, 1) to (N, 1)
                lambda_max = broadcast_nodes(graph, lambda_max)

            re_norm = 2.0 / lambda_max

            # X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
            Xt = X_0 = feat
@@ -128,14 +129,14 @@ class ChebConv(nn.Module):
            # X_1(f)
            if self._k > 1:
                h = unnLaplacian(X_0, D_invsqrt, graph)
                X_1 = -re_norm * h + X_0 * (re_norm - 1)
                # Concatenate Xt and X_1
                Xt = th.cat((Xt, X_1), 1)

            # Xi(x), i = 2...k
            for _ in range(2, self._k):
                h = unnLaplacian(X_1, D_invsqrt, graph)
                X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
                # Concatenate Xt and X_i
                Xt = th.cat((Xt, X_i), 1)
                X_1, X_0 = X_i, X_1
...
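The two branches realize the Chebyshev recurrence T_i(L~)X = 2 * L~ * T_{i-1}(L~)X - T_{i-2}(L~)X on the rescaled Laplacian L~ = re_norm * (I - A^) - I, with the normalized adjacency product A^ X supplied by unnLaplacian and the signs folded into the expressions above. A hedged usage sketch (illustrative values):

import dgl
import torch as th
from dgl.nn import ChebConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feat = th.ones(3, 5)
conv = ChebConv(in_feats=5, out_feats=2, k=3)
out = conv(g, feat, lambda_max=[2.0])  # passing lambda_max skips the warning above
print(out.shape)  # torch.Size([3, 2])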
@@ -53,11 +53,8 @@ class DenseChebConv(nn.Module):
    --------
    `ChebConv <https://docs.dgl.ai/api/python/nn.pytorch.html#chebconv>`__
    """

    def __init__(self, in_feats, out_feats, k, bias=True):
        super(DenseChebConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
@@ -66,7 +63,7 @@ class DenseChebConv(nn.Module):
        if bias:
            self.bias = nn.Parameter(th.Tensor(out_feats))
        else:
            self.register_buffer("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
@@ -74,7 +71,7 @@ class DenseChebConv(nn.Module):
        if self.bias is not None:
            init.zeros_(self.bias)
        for i in range(self._k):
            init.xavier_normal_(self.W[i], init.calculate_gain("relu"))

    def forward(self, adj, feat, lambda_max=None):
        r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer
@@ -120,7 +117,7 @@ class DenseChebConv(nn.Module):
        Zs = th.stack(Z, 0)  # (k, n, n)
        Zh = Zs @ feat.unsqueeze(0) @ self.W
        Zh = Zh.sum(0)

        if self.bias is not None:
...
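Zs stacks the k dense Chebyshev matrices, so the batched product Zs @ feat @ W yields one (n, out_feats) term per polynomial order before the sum over orders. A hedged usage sketch with an explicit dense adjacency (values illustrative; passing lambda_max as a scalar is an assumption about this dense variant's accepted types):

import torch as th
from dgl.nn import DenseChebConv

adj = th.tensor([[0.0, 1.0, 1.0],
                 [1.0, 0.0, 1.0],
                 [1.0, 1.0, 0.0]])  # dense adjacency of a triangle graph
feat = th.randn(3, 4)
conv = DenseChebConv(in_feats=4, out_feats=2, k=2)
out = conv(adj, feat, lambda_max=2.0)  # (3, 2)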