Unverified Commit d6c12f07 authored by 鄢振宇Michael Yan's avatar 鄢振宇Michael Yan Committed by GitHub
Browse files

[NN] Fix #5642 (#5652)

parent ea706cae
......@@ -120,24 +120,28 @@ class ChebConv(nn.Module):
lambda_max = broadcast_nodes(graph, lambda_max)
re_norm = 2.0 / lambda_max
# X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
Xt = X_0 = feat
# X_0 is the raw feature, Xt is the list of X_0, X_1, ... X_t
X_0 = feat
Xt = [X_0]
# X_1(f)
if self._k > 1:
h = unnLaplacian(X_0, D_invsqrt, graph)
X_1 = -re_norm * h + X_0 * (re_norm - 1)
# Concatenate Xt and X_1
Xt = th.cat((Xt, X_1), 1)
# Append X_1 to Xt
Xt.append(X_1)
# Xi(x), i = 2...k
for _ in range(2, self._k):
h = unnLaplacian(X_1, D_invsqrt, graph)
X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
# Concatenate Xt and X_i
Xt = th.cat((Xt, X_i), 1)
# Add X_1 to Xt
Xt.append(X_i)
X_1, X_0 = X_i, X_1
# Create the concatenation
Xt = th.cat(Xt, dim=1)
# linear projection
h = self.linear(Xt)
......
......@@ -127,24 +127,28 @@ class ChebConv(layers.Layer):
lambda_max = broadcast_nodes(graph, lambda_max)
re_norm = 2.0 / lambda_max
# X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
Xt = X_0 = feat
# X_0 is the raw feature, Xt is the list of X_0, X_1, ... X_t
X_0 = feat
Xt = [X_0]
# X_1(f)
if self._k > 1:
h = unnLaplacian(X_0, D_invsqrt, graph)
X_1 = -re_norm * h + X_0 * (re_norm - 1)
# Concatenate Xt and X_1
Xt = tf.concat((Xt, X_1), 1)
# Append X_1 to Xt
Xt.append(X_1)
# Xi(x), i = 2...k
for _ in range(2, self._k):
h = unnLaplacian(X_1, D_invsqrt, graph)
X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
# Concatenate Xt and X_i
Xt = tf.concat((Xt, X_i), 1)
# Append X_i to Xt
Xt.append(X_i)
X_1, X_0 = X_i, X_1
# Create the concatenation
Xt = tf.concat(Xt, 1)
# linear projection
h = self.linear(Xt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment