"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "c83ec3555dc1399789f8157475cb5fef7e691575"
Unverified commit b6c4deb4, authored by Axel Nilsson and committed by GitHub

[NN] ChebConv update (#1460)



* New version for the convolutional layer

* Minor changes

* Minor changes

* Update python/dgl/nn/pytorch/conv/chebconv.py
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>

* Resolved variable misnaming, simplified imports, and added a warning

* Used dgl_warning instead of warnings.warn

* add doc

* upd

* upd
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
Co-authored-by: yzh119 <expye@outlook.com>
parent d4daceb8
python/dgl/nn/pytorch/conv/chebconv.py
@@ -2,8 +2,9 @@
 # pylint: disable= no-member, arguments-differ, invalid-name
 import torch as th
 from torch import nn
-from torch.nn import init
+import torch.nn.functional as F
+from ....base import dgl_warning
 from .... import laplacian_lambda_max, broadcast_nodes, function as fn
@@ -31,6 +32,8 @@ class ChebConv(nn.Module):
         Number of output features.
     k : int
         Chebyshev filter size.
+    activation : function, optional
+        Activation function, default is ReLU.
     bias : bool, optional
         If True, adds a learnable bias to the output. Default: ``True``.
     """
@@ -39,29 +42,14 @@ class ChebConv(nn.Module):
                  in_feats,
                  out_feats,
                  k,
+                 activation=F.relu,
                  bias=True):
         super(ChebConv, self).__init__()
+        self._k = k
         self._in_feats = in_feats
         self._out_feats = out_feats
-        self.fc = nn.ModuleList([
-            nn.Linear(in_feats, out_feats, bias=False) for _ in range(k)
-        ])
-        self._k = k
-        if bias:
-            self.bias = nn.Parameter(th.Tensor(out_feats))
-        else:
-            self.register_buffer('bias', None)
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        """Reinitialize learnable parameters."""
-        if self.bias is not None:
-            init.zeros_(self.bias)
-        for module in self.fc.modules():
-            if isinstance(module, nn.Linear):
-                init.xavier_normal_(module.weight, init.calculate_gain('relu'))
-                if module.bias is not None:
-                    init.zeros_(module.bias)
+        self.activation = activation
+        self.linear = nn.Linear(k * in_feats, out_feats, bias)
 
     def forward(self, graph, feat, lambda_max=None):
         r"""Compute ChebNet layer.
@@ -86,42 +74,58 @@
         The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
         is size of output feature.
         """
+        def unnLaplacian(feat, D_invsqrt, graph):
+            """ Operation Feat * D^-1/2 A D^-1/2 """
+            graph.ndata['h'] = feat * D_invsqrt
+            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
+            return graph.ndata.pop('h') * D_invsqrt
+
         with graph.local_scope():
-            norm = th.pow(
-                graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
+            D_invsqrt = th.pow(graph.in_degrees().float().clamp(
+                min=1), -0.5).unsqueeze(-1).to(feat.device)
+
             if lambda_max is None:
-                lambda_max = laplacian_lambda_max(graph)
+                try:
+                    lambda_max = laplacian_lambda_max(graph)
+                except BaseException:
+                    # if the largest eigenvalue is not found
+                    dgl_warning(
+                        "Largest eigenvalue not found, using default value 2 for lambda_max",
+                        RuntimeWarning)
+                    lambda_max = th.Tensor(2).to(feat.device)
+
             if isinstance(lambda_max, list):
                 lambda_max = th.Tensor(lambda_max).to(feat.device)
             if lambda_max.dim() == 1:
                 lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
+
             # broadcast from (B, 1) to (N, 1)
             lambda_max = broadcast_nodes(graph, lambda_max)
-            # T0(X)
-            Tx_0 = feat
-            rst = self.fc[0](Tx_0)
-            # T1(X)
+            re_norm = 2. / lambda_max
+
+            # X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
+            Xt = X_0 = feat
+
+            # X_1(f)
             if self._k > 1:
-                graph.ndata['h'] = Tx_0 * norm
-                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
-                h = graph.ndata.pop('h') * norm
-                # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
-                #   = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
-                Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1)
-                rst = rst + self.fc[1](Tx_1)
-            # Ti(x), i = 2...k
-            for i in range(2, self._k):
-                graph.ndata['h'] = Tx_1 * norm
-                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
-                h = graph.ndata.pop('h') * norm
-                # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
-                #      = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
-                #        (4 / lambda_max - 2) Tx_(k-1) -
-                #        Tx_(k-2)
-                Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max - 2) - Tx_0
-                rst = rst + self.fc[i](Tx_2)
-                Tx_1, Tx_0 = Tx_2, Tx_1
-            # add bias
-            if self.bias is not None:
-                rst = rst + self.bias
-            return rst
+                h = unnLaplacian(X_0, D_invsqrt, graph)
+                X_1 = - re_norm * h + X_0 * (re_norm - 1)
+                # Concatenate Xt and X_1
+                Xt = th.cat((Xt, X_1), 1)
+
+            # Xi(x), i = 2...k
+            for _ in range(2, self._k):
+                h = unnLaplacian(X_1, D_invsqrt, graph)
+                X_i = - 2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
+                # Concatenate Xt and X_i
+                Xt = th.cat((Xt, X_i), 1)
+                X_1, X_0 = X_i, X_1
+
+            # linear projection
+            h = self.linear(Xt)
+
+            # activation
+            if self.activation:
+                h = self.activation(h)
+
+        return h
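
The rewritten forward builds the Chebyshev basis with the standard recurrence on the rescaled Laplacian L_hat = 2 * (I - D^-1/2 A D^-1/2) / lambda_max - I: T_0 = X, T_1 = L_hat X, and T_i = 2 * L_hat * T_{i-1} - T_{i-2}, which is what the re_norm expressions above compute. A dense, single-graph sketch of the same recurrence (illustrative only, not the DGL message-passing code; it assumes a symmetric float adjacency matrix A):

    import torch as th

    def cheb_basis_dense(A, X, k, lambda_max=2.0):
        # A_hat = D^-1/2 A D^-1/2; L_hat = 2 * (I - A_hat) / lambda_max - I
        D_invsqrt = A.sum(dim=1).clamp(min=1).pow(-0.5)
        A_hat = D_invsqrt.unsqueeze(1) * A * D_invsqrt.unsqueeze(0)
        re_norm = 2.0 / lambda_max

        X_0 = X
        Xs = [X_0]
        if k > 1:
            # T_1 = L_hat X_0 = -re_norm * A_hat X_0 + (re_norm - 1) X_0
            X_1 = -re_norm * (A_hat @ X_0) + (re_norm - 1) * X_0
            Xs.append(X_1)
        for _ in range(2, k):
            # T_i = 2 L_hat T_{i-1} - T_{i-2}
            X_i = -2 * re_norm * (A_hat @ X_1) + 2 * (re_norm - 1) * X_1 - X_0
            Xs.append(X_i)
            X_1, X_0 = X_i, X_1
        return th.cat(Xs, dim=1)   # (N, k * F), the dense analogue of Xt above

Concatenating these bases and applying a single linear projection is what self.linear(Xt) does above. The test hunk below adjusts test_dense_cheb_conv to the new parameterization:
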
@@ -679,17 +679,19 @@ def test_dense_cheb_conv():
     ctx = F.ctx()
     g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
     adj = g.adjacency_matrix(ctx=ctx).to_dense()
-    cheb = nn.ChebConv(5, 2, k)
+    cheb = nn.ChebConv(5, 2, k, None)
     dense_cheb = nn.DenseChebConv(5, 2, k)
-    for i in range(len(cheb.fc)):
-        dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
-    if cheb.bias is not None:
-        dense_cheb.bias.data = cheb.bias.data
+    #for i in range(len(cheb.fc)):
+    #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
+    dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
+    if cheb.linear.bias is not None:
+        dense_cheb.bias.data = cheb.linear.bias.data
     feat = F.randn((100, 5))
     cheb = cheb.to(ctx)
     dense_cheb = dense_cheb.to(ctx)
     out_cheb = cheb(g, feat, [2.0])
     out_dense_cheb = dense_cheb(adj, feat, 2.0)
+    print(k, out_cheb, out_dense_cheb)
     assert F.allclose(out_cheb, out_dense_cheb)
 
 def test_sequential():
...
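
The updated test maps the single (out_feats, k * in_feats) weight of cheb.linear onto DenseChebConv's per-order (k, in_feats, out_feats) parameter via transpose(-1, -2).view(k, 5, 2). A standalone sketch of what that reshape does (illustrative shapes only, not taken from the test file):

    import torch as th

    k, in_feats, out_feats = 3, 5, 2
    W_linear = th.randn(out_feats, k * in_feats)   # layout of cheb.linear.weight

    # Transpose to (k * in_feats, out_feats), then split the first axis into k blocks;
    # block i holds the weights applied to the i-th Chebyshev basis X_i.
    W_dense = W_linear.transpose(-1, -2).view(k, in_feats, out_feats)

    x = th.randn(10, in_feats)
    i = 1
    # Projecting with block i equals projecting with the matching column slice of W_linear.
    assert th.allclose(x @ W_dense[i],
                       x @ W_linear[:, i * in_feats:(i + 1) * in_feats].t())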