Unverified Commit b6c4deb4 authored by Axel Nilsson, committed by GitHub

[NN] ChebConv update (#1460)



* New version for the convolutional layer

* Minor changes

* Minor changes

* Update python/dgl/nn/pytorch/conv/chebconv.py
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>

* Resolved variable misnaming, simplified imports, and added a warning

* Added dgl_warning instead of warnings.warn

* add doc

* upd

* upd
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
Co-authored-by: yzh119 <expye@outlook.com>
parent d4daceb8
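
For context, a minimal usage sketch of the layer after this change (not part of the commit; it assumes a DGL build containing this PR and constructs the graph from a SciPy sparse matrix, as the test below does):

import dgl
import scipy.sparse as sp
import torch as th
import torch.nn.functional as F
from dgl.nn.pytorch import ChebConv

# Small random graph, built the same way as in test_dense_cheb_conv below.
g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
feat = th.randn(100, 5)

# New signature: activation is now a constructor argument (default F.relu),
# and the k Chebyshev blocks are fused into a single nn.Linear.
conv = ChebConv(5, 2, k=3, activation=F.relu, bias=True)

# lambda_max may be omitted (it is then estimated via laplacian_lambda_max)
# or passed explicitly, e.g. [2.0] for the normalized Laplacian.
out = conv(g, feat, lambda_max=[2.0])  # shape (100, 2)
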
@@ -2,8 +2,9 @@
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from ....base import dgl_warning
from .... import laplacian_lambda_max, broadcast_nodes, function as fn
@@ -31,6 +32,8 @@ class ChebConv(nn.Module):
Number of output features.
k : int
Chebyshev filter size.
activation : function, optional
Activation function. Default: ReLU.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
"""
@@ -39,29 +42,14 @@
in_feats,
out_feats,
k,
activation=F.relu,
bias=True):
super(ChebConv, self).__init__()
self._k = k
self._in_feats = in_feats
self._out_feats = out_feats
self.fc = nn.ModuleList([
nn.Linear(in_feats, out_feats, bias=False) for _ in range(k)
])
self._k = k
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
if self.bias is not None:
init.zeros_(self.bias)
for module in self.fc.modules():
if isinstance(module, nn.Linear):
init.xavier_normal_(module.weight, init.calculate_gain('relu'))
if module.bias is not None:
init.zeros_(module.bias)
self.activation = activation
self.linear = nn.Linear(k * in_feats, out_feats, bias)
def forward(self, graph, feat, lambda_max=None):
r"""Compute ChebNet layer.
@@ -86,42 +74,58 @@
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
def unnLaplacian(feat, D_invsqrt, graph):
""" Operation Feat * D^-1/2 A D^-1/2 """
graph.ndata['h'] = feat * D_invsqrt
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
return graph.ndata.pop('h') * D_invsqrt
with graph.local_scope():
norm = th.pow(
graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
D_invsqrt = th.pow(graph.in_degrees().float().clamp(
min=1), -0.5).unsqueeze(-1).to(feat.device)
if lambda_max is None:
lambda_max = laplacian_lambda_max(graph)
try:
lambda_max = laplacian_lambda_max(graph)
except BaseException:
# if the largest eigenvalue is not found
dgl_warning(
"Largest eigonvalue not found, using default value 2 for lambda_max",
RuntimeWarning)
lambda_max = th.tensor([2.], device=feat.device)
if isinstance(lambda_max, list):
lambda_max = th.Tensor(lambda_max).to(feat.device)
if lambda_max.dim() == 1:
lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1)
# broadcast from (B, 1) to (N, 1)
lambda_max = broadcast_nodes(graph, lambda_max)
# T0(X)
Tx_0 = feat
rst = self.fc[0](Tx_0)
# T1(X)
re_norm = 2. / lambda_max
# X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
Xt = X_0 = feat
# X_1(f)
if self._k > 1:
graph.ndata['h'] = Tx_0 * norm
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h = graph.ndata.pop('h') * norm
# Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
# = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1)
rst = rst + self.fc[1](Tx_1)
# Ti(x), i = 2...k
for i in range(2, self._k):
graph.ndata['h'] = Tx_1 * norm
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h = graph.ndata.pop('h') * norm
# Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
# = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
# (4 / lambda_max - 2) Tx_(k-1) -
# Tx_(k-2)
Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max - 2) - Tx_0
rst = rst + self.fc[i](Tx_2)
Tx_1, Tx_0 = Tx_2, Tx_1
# add bias
if self.bias is not None:
rst = rst + self.bias
return rst
h = unnLaplacian(X_0, D_invsqrt, graph)
X_1 = - re_norm * h + X_0 * (re_norm - 1)
# Concatenate Xt and X_1
Xt = th.cat((Xt, X_1), 1)
# Xi(x), i = 2...k
for _ in range(2, self._k):
h = unnLaplacian(X_1, D_invsqrt, graph)
X_i = - 2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
# Concatenate Xt and X_i
Xt = th.cat((Xt, X_i), 1)
X_1, X_0 = X_i, X_1
# linear projection
h = self.linear(Xt)
# activation
if self.activation:
h = self.activation(h)
return h
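
For reference, the recurrence that the rewritten forward pass builds and concatenates (a summary of the comments above, with \tilde{A} = D^{-1/2} A D^{-1/2} and \hat{L} the rescaled Laplacian):

\hat{L} = \frac{2}{\lambda_{\max}}\bigl(I - \tilde{A}\bigr) - I, \qquad
X^{(0)} = X, \qquad
X^{(1)} = \hat{L} X^{(0)}, \qquad
X^{(t)} = 2 \hat{L} X^{(t-1)} - X^{(t-2)}, \quad t = 2, \dots, k-1,

h = \sigma\Bigl( \bigl[ X^{(0)} \,\|\, X^{(1)} \,\|\, \cdots \,\|\, X^{(k-1)} \bigr] W + b \Bigr).
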
@@ -679,17 +679,19 @@ def test_dense_cheb_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
cheb = nn.ChebConv(5, 2, k)
cheb = nn.ChebConv(5, 2, k, None)
dense_cheb = nn.DenseChebConv(5, 2, k)
for i in range(len(cheb.fc)):
dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
if cheb.bias is not None:
dense_cheb.bias.data = cheb.bias.data
#for i in range(len(cheb.fc)):
# dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
if cheb.linear.bias is not None:
dense_cheb.bias.data = cheb.linear.bias.data
feat = F.randn((100, 5))
cheb = cheb.to(ctx)
dense_cheb = dense_cheb.to(ctx)
out_cheb = cheb(g, feat, [2.0])
out_dense_cheb = dense_cheb(adj, feat, 2.0)
print(k, out_cheb, out_dense_cheb)
assert F.allclose(out_cheb, out_dense_cheb)
def test_sequential():
......
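
A small equivalence sketch (not part of the commit; shapes chosen to match the test above) of why the mapping dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2) is valid: a single fused nn.Linear applied to the concatenation [X_0 || ... || X_{k-1}] equals the sum of per-order projections X_i @ W_i, where W_i is the i-th (in_feats, out_feats) slice of the fused weight.

import torch as th
from torch import nn

k, in_feats, out_feats, n = 3, 5, 2, 100
linear = nn.Linear(k * in_feats, out_feats)
Xs = [th.randn(n, in_feats) for _ in range(k)]

# Fused projection over the concatenated Chebyshev terms.
fused = linear(th.cat(Xs, dim=1))

# Per-order weights recovered exactly as in the test above.
W = linear.weight.data.transpose(-1, -2).view(k, in_feats, out_feats)
per_order = sum(Xs[i] @ W[i] for i in range(k)) + linear.bias

assert th.allclose(fused, per_order, atol=1e-6)
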