Unverified Commit dda103d9 authored by Quan (Andy) Gan, Committed by GitHub

Fix bug in EdgeConv implementation (#2669)

* Update edgeconv.py

* update edgeconv to builtin

* add explanation on why ReLU is not there

* remove redundant code
parent d93a9759
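The "add explanation on why ReLU is not there" bullet rests on a simple identity: ReLU is monotone non-decreasing, so over a non-empty neighborhood it commutes with the elementwise max, i.e. max_j ReLU(z_j) = ReLU(max_j z_j). A minimal numerical check of that identity (illustrative only; the tensor `z` and its shape are made up for the example):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(5, 8)  # 5 hypothetical neighbor messages, 8 features each

# max over neighbors of ReLU(message) vs ReLU of the max over neighbors
max_of_relu = F.relu(z).max(dim=0).values
relu_of_max = F.relu(z.max(dim=0).values)

assert torch.allclose(max_of_relu, relu_of_max)
```

This is why the updated docstring can drop the inner ReLU: the activation can instead be applied after the max without changing the result.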
@@ -19,12 +19,18 @@ class EdgeConv(nn.Block):
     <https://arxiv.org/pdf/1801.07829>`__". Can be described as follows:
 
     .. math::
-       h_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}(
+       h_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} (
        \Theta \cdot (h_j^{(l)} - h_i^{(l)}) + \Phi \cdot h_i^{(l)})
 
     where :math:`\mathcal{N}(i)` is the neighbor of :math:`i`.
     :math:`\Theta` and :math:`\Phi` are linear layers.
 
+    .. note::
+
+       The original formulation includes a ReLU inside the maximum operator.
+       This is equivalent to first applying a maximum operator then applying
+       the ReLU.
+
     Parameters
     ----------
     in_feat : int
@@ -114,13 +120,6 @@ class EdgeConv(nn.Block):
         if batch_norm:
             self.bn = nn.BatchNorm(in_channels=out_feat)
 
-    def message(self, edges):
-        r"""The message computation function
-        """
-        theta_x = self.theta(edges.dst['x'] - edges.src['x'])
-        phi_x = self.phi(edges.src['x'])
-        return {'e': theta_x + phi_x}
-
     def set_allow_zero_in_degree(self, set_value):
         r"""
@@ -182,10 +181,13 @@ class EdgeConv(nn.Block):
             h_src, h_dst = expand_as_pair(h, g)
             g.srcdata['x'] = h_src
             g.dstdata['x'] = h_dst
+            g.apply_edges(fn.v_sub_u('x', 'x', 'theta'))
+            g.edata['theta'] = self.theta(g.edata['theta'])
+            g.dstdata['phi'] = self.phi(g.dstdata['x'])
             if not self.batch_norm:
-                g.update_all(self.message, fn.max('e', 'x'))
+                g.update_all(fn.e_add_v('theta', 'phi', 'e'), fn.max('e', 'x'))
             else:
-                g.apply_edges(self.message)
+                g.apply_edges(fn.e_add_v('theta', 'phi', 'e'))
                 g.edata['e'] = self.bn(g.edata['e'])
                 g.update_all(fn.copy_e('e', 'm'), fn.max('m', 'x'))
             return g.dstdata['x']
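The hunks below apply the same replacement to the PyTorch nn.Module variant, so here is a single standalone sketch of the builtin sequence the new forward uses, run on a toy graph with the PyTorch backend. The graph, feature sizes, and the `theta`/`phi` linear layers are made up for illustration; only the order of `fn.*` calls mirrors the diff:

```python
import dgl
import dgl.function as fn
import torch
import torch.nn as nn

# Toy directed graph (a 4-node cycle) with made-up features.
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
g.ndata['x'] = torch.randn(4, 10)
theta, phi = nn.Linear(10, 16), nn.Linear(10, 16)

with g.local_scope():
    # Per-edge difference dst - src via a builtin, then the Theta projection.
    g.apply_edges(fn.v_sub_u('x', 'x', 'theta'))
    g.edata['theta'] = theta(g.edata['theta'])
    # Phi projection of the destination (updated) node's feature.
    g.ndata['phi'] = phi(g.ndata['x'])
    # Fuse both terms on each edge and max-reduce into the destination nodes.
    g.update_all(fn.e_add_v('theta', 'phi', 'e'), fn.max('e', 'h'))
    out = g.ndata['h']

print(out.shape)  # torch.Size([4, 16])
```

Using builtin message/reduce pairs lets DGL run the aggregation with its fused kernels instead of a per-edge Python UDF, which is what the "update edgeconv to builtin" bullet in the commit message refers to.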
@@ -18,12 +18,18 @@ class EdgeConv(nn.Module):
     <https://arxiv.org/pdf/1801.07829>`__". Can be described as follows:
 
     .. math::
-       h_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}(
+       h_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} (
        \Theta \cdot (h_j^{(l)} - h_i^{(l)}) + \Phi \cdot h_i^{(l)})
 
     where :math:`\mathcal{N}(i)` is the neighbor of :math:`i`.
     :math:`\Theta` and :math:`\Phi` are linear layers.
 
+    .. note::
+
+       The original formulation includes a ReLU inside the maximum operator.
+       This is equivalent to first applying a maximum operator then applying
+       the ReLU.
+
     Parameters
     ----------
     in_feat : int
@@ -105,13 +111,6 @@ class EdgeConv(nn.Module):
         if batch_norm:
             self.bn = nn.BatchNorm1d(out_feat)
 
-    def message(self, edges):
-        """The message computation function.
-        """
-        theta_x = self.theta(edges.dst['x'] - edges.src['x'])
-        phi_x = self.phi(edges.src['x'])
-        return {'e': theta_x + phi_x}
-
     def set_allow_zero_in_degree(self, set_value):
         r"""
@@ -173,10 +172,13 @@ class EdgeConv(nn.Module):
             h_src, h_dst = expand_as_pair(feat, g)
             g.srcdata['x'] = h_src
             g.dstdata['x'] = h_dst
+            g.apply_edges(fn.v_sub_u('x', 'x', 'theta'))
+            g.edata['theta'] = self.theta(g.edata['theta'])
+            g.dstdata['phi'] = self.phi(g.dstdata['x'])
             if not self.batch_norm:
-                g.update_all(self.message, fn.max('e', 'x'))
+                g.update_all(fn.e_add_v('theta', 'phi', 'e'), fn.max('e', 'x'))
             else:
-                g.apply_edges(self.message)
+                g.apply_edges(fn.e_add_v('theta', 'phi', 'e'))
                 # Although the official implementation includes a per-edge
                 # batch norm within EdgeConv, I choose to replace it with a
                 # global batch norm for a number of reasons:
......
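For completeness, a hedged end-to-end usage sketch of the layer through the public PyTorch API, exercising the batch-norm branch touched above. The import path and constructor arguments follow DGL's documented `EdgeConv(in_feat, out_feat, batch_norm=...)`; the graph and sizes are made up, and self-loops are added only so the zero-in-degree check does not trigger:

```python
import dgl
import torch
from dgl.nn import EdgeConv

# Random toy graph; self-loops guarantee every node has in-degree >= 1.
g = dgl.add_self_loop(dgl.rand_graph(30, 120))
feat = torch.randn(30, 10)

conv = EdgeConv(10, 16, batch_norm=True)  # takes the BN-on-edge-features branch
out = conv(g, feat)
print(out.shape)  # torch.Size([30, 16])
```

With `batch_norm=True` the normalization is applied to `g.edata['e']` before the final max reduction, as shown in full in the nn.Block version of the forward above.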