Unverified Commit ddb5d804 authored by VoVAllen, committed by GitHub

[Refactor] Break NN modules into files (#859)

* break nn modules into files

* break mxnet nn modules

* fix lint

* fix lint
parent 4cd5c19e
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn
from torch.nn import functional as F
from .... import function as fn
class SAGEConv(nn.Module):
r"""GraphSAGE layer from paper `Inductive Representation Learning on
Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.
.. math::
h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
\left(\{h_{j}^{(l)}, \forall j \in \mathcal{N}(i)\}\right)
h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
(h_{i}^{(l)}, h_{\mathcal{N}(i)}^{(l+1)}) + b \right)
h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
aggregator_type : str
Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
feat_drop : float
Dropout rate on features. Default: ``0``.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
"""
def __init__(self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(in_feats, in_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(in_feats, in_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(in_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(in_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
def _lstm_reducer(self, nodes):
"""LSTM reducer
NOTE(zihao): the lstm reducer with the default schedule (degree bucketing)
is slow; we could accelerate it with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_feats)),
m.new_zeros((1, batch_size, self._in_feats)))
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
def forward(self, graph, feat):
r"""Compute GraphSAGE layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}`
is the input feature size and :math:`N` is the number of nodes.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})`, where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat
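# Aggregate neighbor features with the configured aggregator; every branch
# below writes the aggregated neighbor representation into graph.ndata['neigh'].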
if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# average over in-neighbors and the node itself (divide by in-degree + 1)
degs = graph.in_degrees().float()
degs = degs.to(feat.device)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = F.relu(self.fc_pool(feat))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
elif self._aggre_type == 'lstm':
graph.ndata['h'] = feat
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.ndata['neigh']
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self.norm is not None:
rst = self.norm(rst)
return rst
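As a quick illustration of the layer above, here is a minimal usage sketch (not part of the original module). It assumes the layer is re-exported as dgl.nn.pytorch.SAGEConv and uses this release's DGLGraph construction API; the graph, feature sizes, and variable names are made up.

import dgl
import torch as th
from dgl.nn.pytorch import SAGEConv   # assumed public export of the class above

# Build a small directed 4-cycle and random node features.
g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
feat = th.randn(4, 16)                # (N, D_in)

# Mean aggregator; other choices are 'gcn', 'pool', and 'lstm'.
conv = SAGEConv(16, 8, aggregator_type='mean')
out = conv(g, feat)                   # (N, D_out) == (4, 8)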
"""Torch Module for Simplifying Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from .... import function as fn
class SGConv(nn.Module):
r"""Simplifying Graph Convolution layer from paper `Simplifying Graph
Convolutional Networks <https://arxiv.org/pdf/1902.07153.pdf>`__.
.. math::
H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l}
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
k : int
Number of hops :math:`K`. Default: ``1``.
cached : bool
If True, the module caches
.. math::
(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta
on the first forward call and reuses it afterwards. This parameter should
only be set to ``True`` in the transductive learning setting.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
"""
def __init__(self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None):
super(SGConv, self).__init__()
self.fc = nn.Linear(in_feats, out_feats, bias=bias)
self._cached = cached
self._cached_h = None
self._k = k
self.norm = norm
def forward(self, graph, feat):
r"""Compute Simplifying Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}`
is the input feature size and :math:`N` is the number of nodes.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})`, where :math:`D_{out}`
is the output feature size.
Notes
-----
If ``cached`` is set to ``True``, ``feat`` and ``graph`` should not change during
training, or you will get wrong results.
"""
graph = graph.local_var()
if self._cached_h is not None:
feat = self._cached_h
else:
# compute normalization
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5)
norm[th.isinf(norm)] = 0
norm = norm.to(feat.device).unsqueeze(1)
# compute (D^{-1/2} A D^{-1/2})^k X
for _ in range(self._k):
feat = feat * norm
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
feat = feat * norm
if self.norm is not None:
feat = self.norm(feat)
# cache feature
if self._cached:
self._cached_h = feat
return self.fc(feat)
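A corresponding usage sketch for SGConv, illustrative only and under the same assumptions about the dgl.nn.pytorch exports and DGLGraph API as the SAGEConv sketch above:

import dgl
import torch as th
from dgl.nn.pytorch import SGConv     # assumed public export of the class above

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
feat = th.randn(4, 16)                # (N, D_in)

# Two propagation hops; set cached=True only when graph and features are fixed.
conv = SGConv(16, 8, k=2)
out = conv(g, feat)                   # shape (4, 8)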
"""Torch Module for Topology Adaptive Graph Convolutional layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from .... import function as fn
class TAGConv(nn.Module):
r"""Topology Adaptive Graph Convolutional layer from paper `Topology
Adaptive Graph Convolutional Networks <https://arxiv.org/pdf/1710.10370.pdf>`__.
.. math::
\mathbf{X}^{\prime} = \sum_{k=0}^{K} \left(\mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}\right)^{k} \mathbf{X} \mathbf{\Theta}_{k},
where :math:`\mathbf{A}` denotes the adjacency matrix and
:math:`D_{ii} = \sum_{j} A_{ij}` its diagonal degree matrix.
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
k : int, optional
Number of hops :math:`K`. Default: ``2``.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
Attributes
----------
lin : torch.nn.Module
The learnable linear module.
"""
def __init__(self,
in_feats,
out_feats,
k=2,
bias=True,
activation=None):
super(TAGConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._k = k
self._activation = activation
self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.lin.weight, gain=gain)
def forward(self, graph, feat):
r"""Compute topology adaptive graph convolution.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}`
is the input feature size and :math:`N` is the number of nodes.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})`, where :math:`D_{out}`
is the output feature size.
"""
graph = graph.local_var()
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
# compute (D^{-1/2} A D^{-1/2})^k X
fstack = [feat]
for _ in range(self._k):
rst = fstack[-1] * norm
graph.ndata['h'] = rst
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.ndata['h']
rst = rst * norm
fstack.append(rst)
rst = self.lin(th.cat(fstack, dim=-1))
if self._activation is not None:
rst = self._activation(rst)
return rst
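And a usage sketch for TAGConv under the same assumptions; note that the linear layer consumes the concatenation of all k-hop features, so its input width is in_feats * (k + 1):

import dgl
import torch as th
from dgl.nn.pytorch import TAGConv    # assumed public export of the class above

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
feat = th.randn(4, 16)                # (N, D_in)

conv = TAGConv(16, 8, k=2)            # internally a Linear(16 * 3, 8)
out = conv(g, feat)                   # shape (4, 8)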
@@ -2,6 +2,8 @@
#pylint: disable=no-member, invalid-name
import torch as th
from torch import nn
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
@@ -86,3 +88,16 @@ def bmm_maybe_select(A, B, index):
else:
BB = B.index_select(0, index)
return th.bmm(A.unsqueeze(1), BB).squeeze()
# pylint: disable=W0235
class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive.
(Identity has already been supported by PyTorch 1.2, we will directly
import torch.nn.Identity in the future)
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
"""Return input"""
return x
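Finally, a tiny check of the Identity placeholder (a hedged sketch; it assumes the class is importable from dgl.nn.pytorch.utils, the module this diff touches). Identity is typically passed where a norm or activation module is expected but no transformation is wanted:

import torch as th
from dgl.nn.pytorch.utils import Identity   # assumed import path for this utils module

identity = Identity()
x = th.randn(2, 3)
assert th.equal(identity(x), x)              # forward returns its input unchanged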