Unverified Commit 5d3f470b authored by Zihao Ye, committed by GitHub

[Feature] DGL Pooling modules (#669)

* removal doc

* glob

* upd

* rm knn

* add softmax

* upd

* upd

* add broadcast and s2s

* optimize max_on

* forsaken changes to heterograph

* upd

* upd

* upd

* upd

* upd

* bugfix

* upd

* upd

* upd

* upd

* format upd

* upd format

* upd doc

* upd

* import order

* upd

* rm warnings

* fix

* upd test

* upd

* upd

* fix device

* upd

* upd

* upd

* upd

* remove 1.1

* upd

* trigger

* trigger

* add more tests

* fix device

* upd

* upd

* refactor

* fix?

* fix

* upd docstring

* refactor

* upd

* fix

* upd

* upd

* upd

* fix

* upd docs

* add shape

* refactor & upd doc

* upd doc

* upd
parent c3516f1a
......@@ -37,3 +37,9 @@ Graph Readout
mean_edges
max_nodes
max_edges
topk_nodes
topk_edges
softmax_nodes
softmax_edges
broadcast_nodes
broadcast_edges
......@@ -11,3 +11,18 @@ dgl.nn.mxnet.conv
.. autoclass:: dgl.nn.mxnet.conv.GraphConv
:members: weight, bias, forward
:show-inheritance:
dgl.nn.mxnet.glob
-----------------
.. automodule:: dgl.nn.mxnet.glob
:members:
dgl.nn.mxnet.softmax
--------------------
.. automodule:: dgl.nn.mxnet.softmax
.. autoclass:: dgl.nn.mxnet.softmax.EdgeSoftmax
:members: forward
:show-inheritance:
......@@ -12,6 +12,43 @@ dgl.nn.pytorch.conv
:members: weight, bias, forward, reset_parameters
:show-inheritance:
dgl.nn.pytorch.glob
-------------------
.. automodule:: dgl.nn.pytorch.glob
.. autoclass:: dgl.nn.pytorch.glob.SumPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.AvgPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.MaxPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SortPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.GlobalAttentionPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.Set2Set
:members: forward
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SetTransformerEncoder
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SetTransformerDecoder
:members:
:show-inheritance:
dgl.nn.pytorch.softmax
----------------------
......
......@@ -74,6 +74,21 @@ def tensor(data, dtype=None):
"""
pass
def as_scalar(data):
"""Returns a scalar whose value is copied from this array.
Parameters
----------
data : Tensor
The input data
Returns
-------
scalar
The scalar value in the tensor.
"""
pass
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
......@@ -293,6 +308,21 @@ def sum(input, dim):
"""
pass
def reduce_sum(input):
"""Returns the sum of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def mean(input, dim):
"""Reduce average the input tensor along the given dim.
......@@ -310,6 +340,21 @@ def mean(input, dim):
"""
pass
def reduce_mean(input):
"""Returns the average of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def max(input, dim):
"""Reduce max the input tensor along the given dim.
......@@ -327,6 +372,121 @@ def max(input, dim):
"""
pass
def reduce_max(input):
"""Returns the max of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def min(input, dim):
"""Reduce min the input tensor along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The reduce dim.
Returns
-------
Tensor
A framework-specific tensor.
"""
pass
def reduce_min(input):
"""Returns the min of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def argsort(input, dim, descending):
"""Return the indices that would sort the input along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dim to sort along.
descending : bool
Controls the sorting order (False: ascending, True: descending)
Returns
-------
Tensor
A framework-specific tensor.
"""
def topk(input, k, dim, descending=True):
"""Return the k largest elements of the given input tensor along the given dimension.
If descending is False then the k smallest elements are returned.
Parameters
----------
input : Tensor
The input tensor.
k : int
The number of elements to return.
dim : int
The dim to sort along.
descending : bool
Controls whether to return the largest or the smallest elements.
Returns
-------
Tensor
A framework-specific tensor holding the top-k values.
"""
pass
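# Illustrative sketch only, not part of the interface contract: assuming NumPy
# stands in as a concrete backend, ``topk`` can be built from ``argsort`` plus a
# slice along ``dim``. The helper name below is hypothetical.
def _topk_numpy_sketch(input, k, dim, descending=True):
    import numpy as np
    order = np.argsort(-input if descending else input, axis=dim)
    index = np.take(order, np.arange(k), axis=dim)   # keep the first k positions
    return np.take_along_axis(input, index, axis=dim)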
def exp(input):
"""Returns a new tensor with the exponential of the elements of the input tensor `input`.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
The output tensor.
"""
pass
def softmax(input, dim=-1):
"""Apply the softmax function on given dimension.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dimension along which to compute softmax.
Returns
-------
Tensor
The output tensor.
"""
pass
def cat(seq, dim):
"""Concat the sequence of tensors in the given dimension.
......@@ -381,6 +541,25 @@ def split(input, sizes_or_sections, dim):
"""
pass
def repeat(input, repeats, dim):
"""Repeats elements of an array.
Parameters
----------
input : Tensor
Input data array
repeats : int
The number of repetitions for each element
dim : int
The dim along which to repeat values.
Returns
-------
Tensor
The obtained tensor.
"""
pass
def gather_row(data, row_index):
"""Slice out the data given the row index.
......@@ -398,6 +577,41 @@ def gather_row(data, row_index):
"""
pass
def slice_axis(data, axis, begin, end):
"""Slice along a given axis.
Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index.
Parameters
----------
data : Tensor
The data tensor.
axis : int
The axis along which to slice the tensor.
begin : int
Indicates the begin index.
end : int
Indicates the end index.
Returns
-------
Tensor
The sliced tensor.
"""
pass
def take(data, indices, dim):
"""Takes elements from an input array along the given dim.
Parameters
----------
data : Tensor
The data tensor.
indices : Tensor
The indices tensor.
dim : int
The dimension to gather along.
Returns
-------
Tensor
The output tensor.
"""
pass
def narrow_row(x, start, stop):
"""Narrow down the tensor along the first dimension.
......@@ -563,6 +777,50 @@ def ones(shape, dtype, ctx):
"""
pass
def pad_packed_tensor(input, lengths, value, l_min=None):
"""Pads a packed batch of variable length tensors with given value.
Parameters
----------
input : Tensor
The input tensor with shape :math:`(N, *)`
lengths : list or tensor
The array of tensor lengths (of the first dimension) :math:`L`.
It should satisfy :math:`\sum_{i=1}^{B}L_i = N`,
where :math:`B` is the length of :math:`L`.
value : float
The value to fill in the tensor.
l_min : int or None, defaults to None
The minimum length each tensor needs to be padded to; if set to None,
there is no minimum length requirement.
Returns
-------
Tensor
The obtained tensor with shape :math:`(B, \max(\max_i(L_i), l_{min}), *)`
"""
pass
def pack_padded_tensor(input, lengths):
"""Packs a tensor containing padded sequence of variable length.
Parameters
----------
input : Tensor
The input tensor with shape :math:`(B, L, *)`, where :math:`B` is
the batch size and :math:`L` is the maximum length of the batch.
lengths : list or tensor
The array of tensor lengths (of the first dimension); its maximum
should equal :math:`L`.
Returns
-------
Tensor
The obtained tensor with shape :math:`(N, *)` where
:math:`N = \sum_{i=1}^{B}L_i`
"""
pass
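# Illustrative round-trip sketch only, not part of the interface contract:
# assuming NumPy stands in as a concrete backend, ``pack_padded_tensor`` should
# invert ``pad_packed_tensor`` for matching ``lengths``. The helper name is
# hypothetical.
def _pad_pack_roundtrip_sketch():
    import numpy as np
    packed = np.arange(10, dtype=np.float32).reshape(5, 2)   # N = 5 rows in total
    lengths = [2, 3]                                         # B = 2 segments
    padded = np.zeros((len(lengths), max(lengths)) + packed.shape[1:], dtype=packed.dtype)
    offset = 0
    for i, l in enumerate(lengths):
        padded[i, :l] = packed[offset:offset + l]            # shape (B, max_len, *)
        offset += l
    repacked = np.concatenate([padded[i, :l] for i, l in enumerate(lengths)], axis=0)
    assert np.allclose(repacked, packed)                     # back to shape (N, *)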
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
"""Computes the sum along segments of a tensor.
......
......@@ -6,6 +6,7 @@ import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import numbers
import builtins
from ... import ndarray as dglnd
from ... import kernel as K
......@@ -38,6 +39,9 @@ def tensor(data, dtype=None):
dtype = np.float32
return nd.array(data, dtype=dtype)
def as_scalar(data):
return data.asscalar()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
......@@ -112,12 +116,41 @@ def copy_to(input, ctx):
def sum(input, dim):
return nd.sum(input, axis=dim)
def reduce_sum(input):
return input.sum()
def mean(input, dim):
return nd.mean(input, axis=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
return nd.max(input, axis=dim)
def reduce_max(input):
return input.max()
def min(input, dim):
return nd.min(input, axis=dim)
def reduce_min(input):
return input.min()
def topk(input, k, dim, descending=True):
return nd.topk(input, axis=dim, k=k, ret_typ='value', is_ascend=not descending)
def argsort(input, dim, descending):
idx = nd.argsort(input, dim, is_ascend=not descending)
idx = nd.cast(idx, dtype='int64')
return idx
def exp(input):
return nd.exp(input)
def softmax(input, dim=-1):
return nd.softmax(input, axis=dim)
def cat(seq, dim):
return nd.concat(*seq, dim=dim)
......@@ -143,6 +176,9 @@ def split(x, sizes_or_sections, dim):
else:
return nd.split(x, sizes_or_sections, axis=dim)
def repeat(input, repeats, dim):
return nd.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
# MXNet workaround for empty row index
if len(row_index) == 0:
......@@ -153,6 +189,17 @@ def gather_row(data, row_index):
else:
return data[row_index,]
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
return nd.slice_axis(data, axis, begin, end)
def take(data, indices, dim):
return nd.take(data, indices, dim)
def narrow_row(data, start, stop):
return data[start:stop]
......@@ -181,6 +228,35 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype, ctx=ctx)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, nd.NDArray):
max_len = as_scalar(lengths.max())
else:
max_len = builtins.max(lengths)
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
ctx = input.context
dtype = input.dtype
x = nd.full((batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return scatter_row(x, index, input).reshape(batch_size, max_len, *old_shape[1:])
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
ctx = input.context
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
# TODO: support other dimensions
assert dim == 0, 'MXNet only supports segment sum on first dimension'
......
......@@ -22,6 +22,11 @@ def cpu():
def tensor(data, dtype=None):
return np.array(data, dtype)
def as_scalar(data):
if data.ndim > 1:
raise ValueError('The data must have shape (1,).')
return data[0]
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
......@@ -75,9 +80,46 @@ def copy_to(input, ctx):
def sum(input, dim):
return np.sum(input, axis=dim)
def reduce_sum(input):
dtype = input.dtype
return np.array(input.sum(), dtype=dtype)
def mean(input, dim):
return np.mean(input, axis=dim)
def reduce_mean(input):
dtype = input.dtype
return np.array(input.mean(), dtype=dtype)
def max(input, dim):
return np.max(input, axis=dim)
def reduce_max(input):
dtype = input.dtype
return np.array(input.max(), dtype=dtype)
def min(input, dim):
return np.min(input, axis=dim)
def reduce_min(input):
dtype = input.dtype
return np.array(input.min(), dtype=dtype)
def argsort(input, dim, descending):
if descending:
return np.argsort(-input, axis=dim)
return np.argsort(input, axis=dim)
def exp(input):
return np.exp(input)
def softmax(input, dim=-1):
max_val = input.max(axis=dim)
minus_max = input - np.expand_dims(max_val, axis=dim)
exp_val = np.exp(minus_max)
sum_val = np.sum(exp_val, axis=dim)
return exp_val / np.expand_dims(sum_val, axis=dim)
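# Illustrative sanity check (hypothetical helper, not exported): every slice of
# the softmax output along ``dim`` should sum to one.
def _softmax_sanity_check():
    x = np.arange(6, dtype=np.float64).reshape(2, 3)
    assert np.allclose(softmax(x, dim=-1).sum(axis=-1), 1.0)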
def cat(seq, dim):
return np.concatenate(seq, axis=dim)
......@@ -92,9 +134,20 @@ def split(input, sizes_or_sections, dim):
idx = np.cumsum(sizes_or_sections)[0:-1]
return np.split(input, idx, axis=dim)
def repeat(input, repeats, dim):
return np.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
return data[row_index]
def slice_axis(data, axis, begin, end):
if begin >= end:
raise IndexError("Begin index ({}) equals or greater than end index ({})".format(begin, end))
return np.take(data, np.arange(begin, end), axis=axis)
def take(data, indices, dim):
return np.take(data, indices, axis=dim)
def scatter_row(data, row_index, value):
# NOTE: inplace instead of out-place
data[row_index] = value
......
......@@ -3,6 +3,7 @@ from __future__ import absolute_import
from distutils.version import LooseVersion
import torch as th
import builtins
from torch.utils import dlpack
from ... import ndarray as nd
......@@ -26,6 +27,9 @@ def cpu():
def tensor(data, dtype=None):
return th.tensor(data, dtype=dtype)
def as_scalar(data):
return data.item()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
......@@ -90,13 +94,41 @@ def copy_to(input, ctx):
def sum(input, dim):
return th.sum(input, dim=dim)
def reduce_sum(input):
return input.sum()
def mean(input, dim):
return th.mean(input, dim=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
# NOTE: the second argmax array is not returned
return th.max(input, dim=dim)[0]
def reduce_max(input):
return input.max()
def min(input, dim):
# NOTE: the second argmin array is not returned
return th.min(input, dim=dim)[0]
def reduce_min(input):
return input.min()
def argsort(input, dim, descending):
return th.argsort(input, dim=dim, descending=descending)
def topk(input, k, dim, descending=True):
return th.topk(input, k, dim, largest=descending)[0]
def exp(input):
return th.exp(input)
def softmax(input, dim=-1):
return th.softmax(input, dim=dim)
def cat(seq, dim):
return th.cat(seq, dim=dim)
......@@ -106,9 +138,29 @@ def stack(seq, dim):
def split(input, sizes_or_sections, dim):
return th.split(input, sizes_or_sections, dim)
def repeat(input, repeats, dim):
# return th.repeat_interleave(input, repeats, dim) # PyTorch 1.1
if dim < 0:
dim += input.dim()
return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)
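# Illustrative check (assumes PyTorch >= 1.1 so that ``repeat_interleave`` is
# available; the helper name is hypothetical): the stack-and-flatten fallback
# above should match ``th.repeat_interleave``.
def _repeat_check():
    x = th.arange(6).view(2, 3)
    assert th.equal(repeat(x, 2, dim=0), th.repeat_interleave(x, 2, dim=0))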
def gather_row(data, row_index):
return th.index_select(data, 0, row_index)
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
if begin >= end:
raise IndexError("Begin index ({}) equals or greater than end index ({})".format(begin, end))
return th.index_select(data, axis, th.arange(begin, end, device=data.device))
def take(data, indices, dim):
new_shape = data.shape[:dim] + indices.shape + data.shape[dim+1:]
return th.index_select(data, dim, indices.view(-1)).view(new_shape)
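# Illustrative check (hypothetical helper): ``take`` gathers along ``dim`` and
# splices the index shape into the result, mirroring numpy.take semantics.
def _take_check():
    x = th.arange(12).view(3, 4)
    idx = th.tensor([[0, 2], [1, 3]])
    out = take(x, idx, 1)
    assert out.shape == (3, 2, 2)
    assert th.equal(out[0], x[0][idx])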
def narrow_row(x, start, stop):
return x[start:stop]
......@@ -136,6 +188,35 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, th.Tensor):
max_len = as_scalar(lengths.max())
else:
max_len = builtins.max(lengths)
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
device = input.device
x = input.new(batch_size * max_len, *old_shape[1:])
x.fill_(value)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
device = input.device
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return gather_row(input.view(batch_size * max_len, -1), index)
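# Illustrative round-trip check (hypothetical helper): ``pack_padded_tensor``
# should invert ``pad_packed_tensor`` when called with the same ``lengths``.
def _pad_pack_check():
    x = th.arange(10, dtype=th.float32).view(5, 2)
    padded = pad_packed_tensor(x, [2, 3], 0)
    assert padded.shape == (2, 3, 2)
    assert th.equal(pack_padded_tensor(padded, [2, 3]), x)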
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input)
seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
......
This diff is collapsed.
"""Package for mxnet-specific NN modules."""
from .conv import *
from .glob import *
from .softmax import *
"""MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set']
class SumPooling(nn.Block):
r"""Apply sum pooling over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(SumPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute sum pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'SumPooling()'
class AvgPooling(nn.Block):
r"""Apply average pooling over the nodes in the graph.
.. math::
r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(AvgPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute average pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'AvgPooling()'
class MaxPooling(nn.Block):
r"""Apply max pooling over the nodes in the graph.
.. math::
r^{(i)} = \max_{k=1}^{N_i} \left( x^{(i)}_k \right)
"""
def __init__(self):
super(MaxPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute max pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'MaxPooling()'
class SortPooling(nn.Block):
r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
<https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.
Parameters
----------
k : int
The number of nodes to hold for each graph.
"""
def __init__(self, k):
super(SortPooling, self).__init__()
self.k = k
def forward(self, feat, graph):
r"""Compute sort pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`).
"""
# Sort the feature of each node in ascending order.
with graph.local_scope():
feat = feat.sort(axis=-1)
graph.ndata['h'] = feat
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k)[0].reshape(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(axis=0)
def __repr__(self):
return 'SortPooling(k={})'.format(self.k)
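# Minimal usage sketch (illustration only; the helper name and the toy graph are
# assumptions, mirroring the shape contract documented above):
def _sort_pooling_example():
    import networkx as nx
    from ... import DGLGraph
    g = DGLGraph(nx.path_graph(4))
    feat = nd.random.randn(4, 3)
    out = SortPooling(k=2)(feat, g)
    assert out.shape == (2 * 3,)   # (k * D,) for a single graph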
class GlobalAttentionPooling(nn.Block):
r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
<https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
\left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)
Parameters
----------
gate_nn : gluon.nn.Block
A neural network that computes attention scores for each feature.
feat_nn : gluon.nn.Block, optional
A neural network applied to each feature before combining them
with attention scores.
"""
def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__()
with self.name_scope():
self.gate_nn = gate_nn
self.feat_nn = feat_nn
self._reset_parameters()
def _reset_parameters(self):
self.gate_nn.initialize(mx.init.Xavier())
if self.feat_nn:
self.feat_nn.initialize(mx.init.Xavier())
def forward(self, feat, graph):
r"""Compute global attention pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`).
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
return readout
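# Minimal usage sketch (illustration only; the helper name and the toy graph are
# assumptions): ``gate_nn`` must map each feature to a single attention score.
def _global_attention_pooling_example():
    import networkx as nx
    from ... import DGLGraph
    g = DGLGraph(nx.path_graph(4))
    feat = nd.random.randn(4, 5)
    gap = GlobalAttentionPooling(gate_nn=nn.Dense(1))
    out = gap(feat, g)
    assert out.shape == (5,)   # (D,) for a single graph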
class Set2Set(nn.Block):
r"""Apply Set2Set (`Order Matters: Sequence to sequence for sets
<https://arxiv.org/pdf/1511.06391.pdf>`__) over the nodes in the graph.
For each individual graph in the batch, set2set computes
.. math::
q_t &= \mathrm{LSTM} (q^*_{t-1})
\alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t)
r_t &= \sum_{i=1}^N \alpha_{i,t} x_i
q^*_t &= q_t \Vert r_t
for this graph.
Parameters
----------
input_dim : int
Size of each input sample
n_iters : int
Number of iterations.
n_layers : int
Number of recurrent layers.
"""
def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__()
self.input_dim = input_dim
self.output_dim = 2 * input_dim
self.n_iters = n_iters
self.n_layers = n_layers
with self.name_scope():
self.lstm = gluon.rnn.LSTM(
self.input_dim, num_layers=n_layers, input_size=self.output_dim)
self._reset_parameters()
def _reset_parameters(self):
self.lstm.initialize(mx.init.Xavier())
def forward(self, feat, graph):
r"""Compute set2set pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(2 * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, 2 * D)`).
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
q_star = nd.zeros((batch_size, self.output_dim), ctx=feat.context)
for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.ndim == 1: # graph is not a BatchedDGLGraph
readout = readout.expand_dims(0)
q_star = nd.concat(q, readout, dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(axis=0)
def __repr__(self):
summary = 'Set2Set('
summary += 'in={}, out={}, ' \
'n_iters={}, n_layers={}'.format(self.input_dim,
self.output_dim,
self.n_iters,
self.n_layers)
summary += ')'
return summary
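# Minimal usage sketch (illustration only; the helper name and the toy graph are
# assumptions): the readout size is ``2 * input_dim``.
def _set2set_example():
    import networkx as nx
    from ... import DGLGraph
    g = DGLGraph(nx.path_graph(6))
    feat = nd.random.randn(6, 4)
    out = Set2Set(input_dim=4, n_iters=2, n_layers=1)(feat, g)
    assert out.shape == (2 * 4,)   # (2 * D,) for a single graph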
"""Package for pytorch-specific NN modules."""
from .conv import *
from .glob import *
from .softmax import *
This diff is collapsed.
import dgl
import backend as F
import networkx as nx
def test_simple_readout():
g1 = dgl.DGLGraph()
......@@ -57,8 +58,192 @@ def test_simple_readout():
max_bg_e = dgl.max_edges(g, 'x')
assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
# TODO(zihao): fix -inf issue
# assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
def test_topk_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(14))
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
# to test the case where k > number of nodes.
dgl.topk_nodes(g0, 'x', 20, idx=-1)
# test correctness
val, indices = dgl.topk_nodes(g0, 'x', 5, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:5], (5,))
assert F.allclose(ground_truth, indices)
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_nodes(), 10))
bg = dgl.batch([g0, g1])
bg.ndata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of nodes.
dgl.topk_nodes(bg, 'x', 16, idx=1)
# test correctness
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:6], (6,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:6], (6,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_topk_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(14))
feat0 = F.randn((g0.number_of_edges(), 10))
g0.edata['x'] = feat0
# to test the case where k > number of edges.
dgl.topk_edges(g0, 'x', 30, idx=-1)
# test correctness
val, indices = dgl.topk_edges(g0, 'x', 7, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:7], (7,))
assert F.allclose(ground_truth, indices)
g0.edata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_edges(), 10))
bg = dgl.batch([g0, g1])
bg.edata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of edges.
dgl.topk_edges(bg, 'x', 33, idx=1)
# test correctness
val, indices = dgl.topk_edges(bg, 'x', 4, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:4], (4,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:4], (4,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
val, indices = dgl.topk_edges(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_softmax_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(9))
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
ground_truth = F.softmax(feat0, dim=0)
assert F.allclose(dgl.softmax_nodes(g0, 'x'), ground_truth)
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(3))
g3 = dgl.DGLGraph()
g4 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g0, g1, g2, g3, g4])
feat1 = F.randn((g1.number_of_nodes(), 10))
feat2 = F.randn((g2.number_of_nodes(), 10))
feat4 = F.randn((g4.number_of_nodes(), 10))
bg.ndata['x'] = F.cat([feat0, feat1, feat2, feat4], 0)
ground_truth = F.cat([
F.softmax(feat0, 0),
F.softmax(feat1, 0),
F.softmax(feat2, 0),
F.softmax(feat4, 0)
], 0)
assert F.allclose(dgl.softmax_nodes(bg, 'x'), ground_truth)
def test_softmax_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((g0.number_of_edges(), 10))
g0.edata['x'] = feat0
ground_truth = F.softmax(feat0, dim=0)
assert F.allclose(dgl.softmax_edges(g0, 'x'), ground_truth)
g0.edata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(3))
g3 = dgl.DGLGraph()
g4 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g0, g1, g2, g3, g4])
feat1 = F.randn((g1.number_of_edges(), 10))
feat2 = F.randn((g2.number_of_edges(), 10))
feat4 = F.randn((g4.number_of_edges(), 10))
bg.edata['x'] = F.cat([feat0, feat1, feat2, feat4], 0)
ground_truth = F.cat([
F.softmax(feat0, 0),
F.softmax(feat1, 0),
F.softmax(feat2, 0),
F.softmax(feat4, 0)
], 0)
assert F.allclose(dgl.softmax_edges(bg, 'x'), ground_truth)
def test_broadcast_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
[feat0] * g0.number_of_nodes() +\
[feat1] * g1.number_of_nodes() +\
[feat2] * g2.number_of_nodes() +\
[feat3] * g3.number_of_nodes(), 0
)
assert F.allclose(dgl.broadcast_nodes(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
), ground_truth)
def test_broadcast_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
[feat0] * g0.number_of_edges() +\
[feat1] * g1.number_of_edges() +\
[feat2] * g2.number_of_edges() +\
[feat3] * g3.number_of_edges(), 0
)
assert F.allclose(dgl.broadcast_edges(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
), ground_truth)
if __name__ == '__main__':
test_simple_readout()
test_topk_nodes()
test_topk_edges()
test_softmax_nodes()
test_softmax_edges()
test_broadcast_nodes()
test_broadcast_edges()
......@@ -3,11 +3,10 @@ import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
from mxnet import autograd
from mxnet import autograd, gluon
def check_eq(a, b):
assert a.shape == b.shape
assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def check_close(a, b):
assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)
def _AXWb(A, X, W, b):
X = mx.nd.dot(X, W.data(X.context))
......@@ -26,13 +25,13 @@ def test_graph_conv():
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = mx.nd.ones((3, 5, 5))
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
conv.initialize(ctx=ctx)
......@@ -69,7 +68,92 @@ def test_graph_conv():
assert len(g.ndata) == 1
assert len(g.edata) == 0
assert "h" in g.ndata
check_eq(g.ndata['h'], 2 * mx.nd.ones((3, 1)))
check_close(g.ndata['h'], 2 * mx.nd.ones((3, 1)))
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
print(s2s)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = s2s(h0, g)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = s2s(h0, bg)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
print(gap)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = gap(h0, g)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = gap(h0, bg)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_simple_pool():
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
avg_pool = nn.AvgPooling()
max_pool = nn.MaxPooling()
sort_pool = nn.SortPooling(10) # k = 10
print(sum_pool, avg_pool, max_pool, sort_pool)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = sum_pool(h0, g)
check_close(h1, mx.nd.sum(h0, 0))
h1 = avg_pool(h0, g)
check_close(h1, mx.nd.mean(h0, 0))
h1 = max_pool(h0, g)
check_close(h1, mx.nd.max(h0, 0))
h1 = sort_pool(h0, g)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = sum_pool(h0, bg)
truth = mx.nd.stack(mx.nd.sum(h0[:15], 0),
mx.nd.sum(h0[15:20], 0),
mx.nd.sum(h0[20:35], 0),
mx.nd.sum(h0[35:40], 0),
mx.nd.sum(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = avg_pool(h0, bg)
truth = mx.nd.stack(mx.nd.mean(h0[:15], 0),
mx.nd.mean(h0[15:20], 0),
mx.nd.mean(h0[20:35], 0),
mx.nd.mean(h0[35:40], 0),
mx.nd.mean(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = max_pool(h0, bg)
truth = mx.nd.stack(mx.nd.max(h0[:15], 0),
mx.nd.max(h0[15:20], 0),
mx.nd.max(h0[20:35], 0),
mx.nd.max(h0[35:40], 0),
mx.nd.max(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = sort_pool(h0, bg)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def uniform_attention(g, shape):
a = mx.nd.ones(shape)
......@@ -97,3 +181,6 @@ def test_edge_softmax():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
......@@ -61,6 +61,124 @@ def test_graph_conv():
new_weight = conv.weight.data
assert not th.allclose(old_weight, new_weight)
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
print(s2s)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = s2s(h0, g)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(11))
g2 = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g1, g2])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = s2s(h0, bg)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
print(gap)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = gap(h0, g)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = gap(h0, bg)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
def test_simple_pool():
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
avg_pool = nn.AvgPooling()
max_pool = nn.MaxPooling()
sort_pool = nn.SortPooling(10) # k = 10
print(sum_pool, avg_pool, max_pool, sort_pool)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = sum_pool(h0, g)
assert th.allclose(h1, th.sum(h0, 0))
h1 = avg_pool(h0, g)
assert th.allclose(h1, th.mean(h0, 0))
h1 = max_pool(h0, g)
assert th.allclose(h1, th.max(h0, 0)[0])
h1 = sort_pool(h0, g)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = sum_pool(h0, bg)
truth = th.stack([th.sum(h0[:15], 0),
th.sum(h0[15:20], 0),
th.sum(h0[20:35], 0),
th.sum(h0[35:40], 0),
th.sum(h0[40:55], 0)], 0)
assert th.allclose(h1, truth)
h1 = avg_pool(h0, bg)
truth = th.stack([th.mean(h0[:15], 0),
th.mean(h0[15:20], 0),
th.mean(h0[20:35], 0),
th.mean(h0[35:40], 0),
th.mean(h0[40:55], 0)], 0)
assert th.allclose(h1, truth)
h1 = max_pool(h0, bg)
truth = th.stack([th.max(h0[:15], 0)[0],
th.max(h0[15:20], 0)[0],
th.max(h0[20:35], 0)[0],
th.max(h0[35:40], 0)[0],
th.max(h0[40:55], 0)[0]], 0)
assert th.allclose(h1, truth)
h1 = sort_pool(h0, bg)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
g = dgl.DGLGraph(nx.path_graph(15))
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
print(st_enc_0, st_enc_1, st_dec)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 50)
h1 = st_enc_0(h0, g)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, g)
assert h1.shape == h0.shape
h2 = st_dec(h1, g)
assert h2.shape[0] == 200 and h2.dim() == 1
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g1, g2])
h0 = th.rand(bg.number_of_nodes(), 50)
h1 = st_enc_0(h0, bg)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, bg)
assert h1.shape == h0.shape
h2 = st_dec(h1, bg)
assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2
def uniform_attention(g, shape):
a = th.ones(shape)
target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
......@@ -130,3 +248,7 @@ def test_edge_softmax():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
test_set_trans()