"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e4e0132972a09df8f73387c4ea48d2cd19178d2a"
Unverified commit 8005978e, authored by Mufei Li and committed by GitHub

[NN] Grouped reversible residual connections for GNNs (#3842)

* Update

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
parent a3fd0595
@@ -35,6 +35,7 @@ Conv Layers
~dgl.nn.pytorch.conv.TWIRLSUnfoldingAndAttention
~dgl.nn.pytorch.conv.GCN2Conv
~dgl.nn.pytorch.conv.HGTConv
~dgl.nn.pytorch.conv.GroupRevRes

Dense Conv Layers
----------------------------------------
@@ -26,9 +26,10 @@ from .dotgatconv import DotGatConv
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
from .hgtconv import HGTConv
from .grouprevres import GroupRevRes

__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes']
"""Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
class InvertibleCheckpoint(torch.autograd.Function):
r"""Extension of torch.autograd"""
@staticmethod
def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
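# fn computes the block output and fn_inverse recomputes the block input from
# that output; trainable weights are passed explicitly after the first
# num_inputs entries so torch.autograd.grad can return their gradients
# in the backward pass.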
ctx.fn = fn
ctx.fn_inverse = fn_inverse
ctx.weights = inputs_and_weights[num_inputs:]
inputs = inputs_and_weights[:num_inputs]
ctx.input_requires_grad = []
with torch.no_grad():
# Make a detached copy, which shares the storage
x = []
for element in inputs:
if isinstance(element, torch.Tensor):
x.append(element.detach())
ctx.input_requires_grad.append(element.requires_grad)
else:
x.append(element)
ctx.input_requires_grad.append(None)
# Detach the output, which then allows discarding the intermediary results
outputs = ctx.fn(*x).detach_()
# clear memory of input node features
inputs[1].storage().resize_(0)
# store for backward pass
ctx.inputs = [inputs]
ctx.outputs = [outputs]
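# Only the detached output (and the now-empty input tensor) is cached here;
# the freed input features are rebuilt from the output via fn_inverse in
# backward(), so activation memory does not grow with the number of blocks.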
return outputs
@staticmethod
def backward(ctx, *grad_outputs):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible")
# retrieve input and output tensor nodes
if len(ctx.outputs) == 0:
raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint \
for more than once.")
inputs = ctx.inputs.pop()
outputs = ctx.outputs.pop()
# reconstruct input node features
with torch.no_grad():
# inputs[0] is DGLGraph and inputs[1] is input node features
inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs)+inputs[2:]))
# clear memory of outputs
outputs.storage().resize_(0)
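# Refill the input feature tensor in place: its storage was freed in
# forward(), so resize it back and copy in the reconstructed values.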
x = inputs[1]
x.storage().resize_(int(np.prod(x.size())))
x.set_(inputs_inverted)
# compute gradients
with torch.set_grad_enabled(True):
detached_inputs = []
for i, element in enumerate(inputs):
if isinstance(element, torch.Tensor):
element = element.detach()
element.requires_grad = ctx.input_requires_grad[i]
detached_inputs.append(element)
detached_inputs = tuple(detached_inputs)
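# Re-run the forward computation with gradients enabled to build a local
# autograd graph; gradients w.r.t. inputs and weights are then extracted
# with torch.autograd.grad below.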
temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: x.requires_grad, detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs)
input_gradients = []
i = 0
for rg in ctx.input_requires_grad:
if rg:
input_gradients.append(gradients[i])
i += 1
else:
input_gradients.append(None)
gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]
return (None, None, None) + gradients
class GroupRevRes(nn.Module):
r"""Grouped reversible residual connections for GNNs, as introduced in
`Training Graph Neural Networks with 1000 Layers <https://arxiv.org/abs/2106.07476>`__
It uniformly partitions an input node feature :math:`X` into :math:`C` groups
:math:`X_1, X_2, \cdots, X_C` across the channel dimension. Besides, it makes
:math:`C` copies of the input GNN module :math:`f_{w1}, \cdots, f_{wC}`. In the
forward pass, each GNN module only takes the corresponding group of node features.
The output node representations :math:`X^{'}` are computed as follows.
.. math::
X_0^{'} = \sum_{i=2}^{C}X_i
X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}
X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}
where :math:`g` is the input graph, :math:`U` is arbitrary additional input arguments like
edge features, and :math:`\, \Vert \,` is concatenation.

Parameters
----------
gnn_module : nn.Module
    GNN module for message passing. :attr:`GroupRevRes` clones this module
    :attr:`groups` - 1 times, yielding :attr:`groups` copies in total. The input
    and output node representation sizes must be the same. Its forward function
    must take a DGLGraph and the associated input node features in that order,
    optionally followed by additional arguments such as edge features.
groups : int, optional
    The number of groups. Default: 2.

Examples
--------
>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module):
... def __init__(self, feats, dropout=0.2):
... super(GNNLayer, self).__init__()
... # Use BatchNorm and dropout to prevent gradient vanishing
... # In particular if you use a large number of GNN layers
... self.norm = nn.BatchNorm1d(feats)
... self.conv = GraphConv(feats, feats)
... self.dropout = nn.Dropout(dropout)
...
... def forward(self, g, x):
... x = self.norm(x)
... x = self.dropout(x)
... return self.conv(g, x)
>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
"""
def __init__(self, gnn_module, groups=2):
super(GroupRevRes, self).__init__()
self.gnn_modules = nn.ModuleList()
for i in range(groups):
if i == 0:
self.gnn_modules.append(gnn_module)
else:
self.gnn_modules.append(deepcopy(gnn_module))
self.groups = groups
def _forward(self, g, x, *args):
xs = torch.chunk(x, self.groups, dim=-1)
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
args_chunks = list(zip(*chunked_args))
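# Follows the formulation in the class docstring:
# X'_0 = sum_{i>=2} X_i, then X'_i = f_{w_i}(g, X'_{i-1}, U) + X_i for each group,
# and the groups are concatenated back along the feature dimension.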
y_in = sum(xs[1:])
ys = []
for i in range(self.groups):
y_in = xs[i] + self.gnn_modules[i](g, y_in, *args_chunks[i])
ys.append(y_in)
out = torch.cat(ys, dim=-1)
return out
def _inverse(self, g, y, *args):
ys = torch.chunk(y, self.groups, dim=-1)
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
args_chunks = list(zip(*chunked_args))
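# Invert the forward recurrence group by group, from the last group to the
# first: X_i = X'_i - f_{w_i}(g, X'_{i-1}, U), where X'_0 is rebuilt as the
# sum of the already recovered groups X_2, ..., X_C.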
xs = []
for i in range(self.groups-1, -1, -1):
if i != 0:
y_in = ys[i-1]
else:
y_in = sum(xs)
x = ys[i] - self.gnn_modules[i](g, y_in, *args_chunks[i])
xs.append(x)
x = torch.cat(xs[::-1], dim=-1)
return x
def forward(self, g, x, *args):
r"""Apply the GNN module with grouped reversible residual connection.
Parameters
----------
g : DGLGraph
The graph.
x : torch.Tensor
The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}` is size
of input feature, :math:`N` is the number of nodes.
args
Additional arguments to pass to :attr:`gnn_module`.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{in})`.
"""
args = (g, x) + args
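# Trainable parameters are appended to the inputs so that
# InvertibleCheckpoint.backward can compute and return their gradients.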
y = InvertibleCheckpoint.apply(
self._forward,
self._inverse,
len(args),
*(args + tuple([p for p in self.parameters() if p.requires_grad])))
return y
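
The memory saving matters most when many reversible blocks are stacked, which the docstring example does not show. A minimal sketch of that usage (illustrative only; the layer count and plain GraphConv blocks are arbitrary choices, not part of this commit):

>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes
>>> feats, groups, num_layers = 32, 2, 8
>>> g = dgl.rand_graph(5, 20)
>>> x = torch.randn(5, feats)
>>> blocks = nn.ModuleList([
...     GroupRevRes(GraphConv(feats // groups, feats // groups), groups)
...     for _ in range(num_layers)])
>>> h = x
>>> for block in blocks:
...     h = block(g, h)
>>> h.sum().backward()  # each block rebuilds its input from its output instead of storing it
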
@@ -1308,7 +1308,7 @@ def test_hgt(idtype, in_size, num_heads):
etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
x = th.randn(g.num_nodes(), in_size).to(dev)
m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(dev)
y = m(g, x, ntype, etype)
@@ -1329,3 +1329,17 @@ def test_hgt(idtype, in_size, num_heads):
assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
# TODO(minjie): enable the following check
#assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)

@parametrize_dtype
def test_group_rev_res(idtype):
dev = F.ctx()
num_nodes = 5
num_edges = 20
feats = 32
groups = 2
g = dgl.rand_graph(num_nodes, num_edges).to(dev)
h = th.randn(num_nodes, feats).to(dev)
conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h)
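
A natural follow-up check (sketch only, not part of this commit) is to keep the output, verify its shape, and run a backward pass so the inverse reconstruction in InvertibleCheckpoint is exercised; it would replace the bare model(g, h) call:

out = model(g, h)
assert out.shape == (num_nodes, feats)
out.sum().backward()  # rebuilds h from out via GroupRevRes._inverse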