Unverified Commit 5d3f470b authored by Zihao Ye's avatar Zihao Ye Committed by GitHub

[Feature] DGL Pooling modules (#669)

* removal doc

* glob

* upd

* rm knn

* add softmax

* upd

* upd

* add broadcast and s2s

* optimize max_on

* forsaken changes to heterograph

* upd

* upd

* upd

* upd

* upd

* bugfix

* upd

* upd

* upd

* upd

* format upd

* upd format

* upd doc

* upd

* import order

* upd

* rm warnings

* fix

* upd test

* upd

* upd

* fix device

* upd

* upd

* upd

* upd

* remove 1.1

* upd

* trigger

* trigger

* add more tests

* fix device

* upd

* upd

* refactor

* fix?

* fix

* upd docstring

* refactor

* upd

* fix

* upd

* upd

* upd

* fix

* upd docs

* add shape

* refactor & upd doc

* upd doc

* upd
parent c3516f1a
@@ -37,3 +37,9 @@ Graph Readout
mean_edges
max_nodes
max_edges
topk_nodes
topk_edges
softmax_nodes
softmax_edges
broadcast_nodes
broadcast_edges
@@ -11,3 +11,18 @@ dgl.nn.mxnet.conv
.. autoclass:: dgl.nn.mxnet.conv.GraphConv
:members: weight, bias, forward
:show-inheritance:
dgl.nn.mxnet.glob
-----------------
.. automodule:: dgl.nn.mxnet.glob
:members:
dgl.nn.mxnet.softmax
--------------------
.. automodule:: dgl.nn.mxnet.softmax
.. autoclass:: dgl.nn.mxnet.softmax.EdgeSoftmax
:members: forward
:show-inheritance:
@@ -12,6 +12,43 @@ dgl.nn.pytorch.conv
:members: weight, bias, forward, reset_parameters
:show-inheritance:
dgl.nn.pytorch.glob
-------------------
.. automodule:: dgl.nn.pytorch.glob
.. autoclass:: dgl.nn.pytorch.glob.SumPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.AvgPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.MaxPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SortPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.GlobalAttentionPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.Set2Set
:members: forward
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SetTransformerEncoder
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SetTransformerDecoder
:members:
:show-inheritance:
dgl.nn.pytorch.softmax
----------------------
@@ -74,6 +74,21 @@ def tensor(data, dtype=None):
"""
pass
def as_scalar(data):
"""Returns a scalar whose value is copied from this array.
Parameters
----------
data : Tensor
The input data
Returns
-------
scalar
The scalar value in the tensor.
"""
pass
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -293,6 +308,21 @@ def sum(input, dim):
"""
pass
def reduce_sum(input):
"""Returns the sum of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def mean(input, dim):
"""Reduce average the input tensor along the given dim.
@@ -310,6 +340,21 @@ def mean(input, dim):
"""
pass
def reduce_mean(input):
"""Returns the average of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def max(input, dim):
"""Reduce max the input tensor along the given dim.
@@ -327,6 +372,121 @@ def max(input, dim):
"""
pass
def reduce_max(input):
"""Returns the max of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def min(input, dim):
"""Reduce min the input tensor along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The reduce dim.
Returns
-------
Tensor
A framework-specific tensor.
"""
pass
def reduce_min(input):
"""Returns the min of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
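To make the contract of the new reduce_* helpers concrete, here is a minimal sketch, assuming the PyTorch backend implementation shown further down in this diff; the tensor values are purely illustrative.

import torch as th

x = th.tensor([[1., 2.], [3., 4.]])
# sum(input, dim) reduces one axis and keeps the rest.
col_sums = th.sum(x, dim=0)   # tensor([4., 6.])
# reduce_sum(input) collapses all elements into a single value.
total = x.sum()               # tensor(10.)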
def argsort(input, dim, descending):
"""Return the indices that would sort the input along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dim to sort along.
descending : bool
Controls the sorting order (False: ascending, True: descending)
Returns
-------
Tensor
A framework-specific tensor.
"""
def topk(input, k, dim, descending=True):
"""Return the k largest elements of the given input tensor along the given dimension.
If descending is False then the k smallest elements are returned.
Parameters
----------
input : Tensor
The input tensor.
k : int
The number of elements to return.
dim : int
The dim to sort along.
descending : bool
Controls whether to return the largest or smallest elements.
Returns
-------
Tensor
A framework-specific tensor holding the k largest (or smallest) elements along the given dim.
"""
pass
def exp(input):
"""Returns a new tensor with the exponential of the elements of the input tensor `input`.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
The output tensor.
"""
pass
def softmax(input, dim=-1):
"""Apply the softmax function on given dimension.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dimension along which to compute softmax.
Returns
-------
Tensor
The output tensor.
"""
pass
def cat(seq, dim):
"""Concat the sequence of tensors in the given dimension.
@@ -381,6 +541,25 @@ def split(input, sizes_or_sections, dim):
"""
pass
def repeat(input, repeats, dim):
"""Repeats elements of an array.
Parameters
----------
input : Tensor
Input data array
repeats : int
The number of repetitions for each element
dim : int
The dim along which to repeat values.
Returns
-------
Tensor
The obtained tensor.
"""
pass
def gather_row(data, row_index):
"""Slice out the data given the row index.
@@ -398,6 +577,41 @@ def gather_row(data, row_index):
"""
pass
def slice_axis(data, axis, begin, end):
"""Slice along a given axis.
Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index.
Parameters
----------
data : Tensor
The data tensor.
axis : int
The axis along which to slice the tensor.
begin : int
Indicates the begin index.
end : int
Indicates the end index.
Returns
-------
Tensor
The sliced tensor.
"""
pass
def take(data, indices, dim):
"""Takes elements from an input array along the given dim.
Parameters
----------
data : Tensor
The data tensor.
indices : Tensor
The indices tensor.
dim : int
The dimension to gather along.
Returns
-------
Tensor
The output tensor.
"""
pass
def narrow_row(x, start, stop):
"""Narrow down the tensor along the first dimension.
@@ -563,6 +777,50 @@ def ones(shape, dtype, ctx):
"""
pass
def pad_packed_tensor(input, lengths, value, l_min=None):
"""Pads a packed batch of variable length tensors with given value.
Parameters
----------
input : Tensor
The input tensor with shape :math:`(N, *)`
lengths : list or tensor
The array of tensor lengths (of the first dimension) :math:`L`.
It should satisfy :math:`\sum_{i=1}^{B}L_i = N`,
where :math:`B` is the length of :math:`L`.
value : float
The value to fill in the tensor.
l_min : int or None, defaults to None
The minimum length each tensor needs to be padded to; if set to None,
there is no minimum length requirement.
Returns
-------
Tensor
The obtained tensor with shape :math:`(B, \max(\max_i(L_i), l_{min}), *)`
"""
pass
def pack_padded_tensor(input, lengths):
"""Packs a tensor containing padded sequence of variable length.
Parameters
----------
input : Tensor
The input tensor with shape :math:`(B, L, *)`, where :math:`B` is
the batch size and :math:`L` is the maximum length of the batch.
lengths : list or tensor
The array of tensor lengths (of the first dimension) :math:`L`.
:math:`\max_i(L_i)` should equal :math:`L`.
Returns
-------
Tensor
The obtained tensor with shape :math:`(N, *)` where
:math:`N = \sum_{i=1}^{B}L_i`
"""
pass
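As a hedged illustration of how the two helpers above are meant to compose (shapes follow the docstrings; the PyTorch backend implementation later in this diff is assumed, and the import path is an assumption of the sketch), padding a packed batch and then packing it again should round-trip:

import torch as th
from dgl.backend import pad_packed_tensor, pack_padded_tensor  # import path is an assumption

packed = th.arange(5.).view(5, 1)    # N = 5 node rows coming from B = 2 graphs
lengths = [2, 3]                     # L_1 + L_2 = N
padded = pad_packed_tensor(packed, lengths, 0)   # shape (B, max_i L_i, 1) = (2, 3, 1)
restored = pack_padded_tensor(padded, lengths)   # back to shape (5, 1)
assert th.allclose(restored, packed)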
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
"""Computes the sum along segments of a tensor.
@@ -6,6 +6,7 @@ import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import numbers
import builtins
from ... import ndarray as dglnd
from ... import kernel as K
@@ -38,6 +39,9 @@ def tensor(data, dtype=None):
dtype = np.float32
return nd.array(data, dtype=dtype)
def as_scalar(data):
return data.asscalar()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -112,12 +116,41 @@ def copy_to(input, ctx):
def sum(input, dim):
return nd.sum(input, axis=dim)
def reduce_sum(input):
return input.sum()
def mean(input, dim):
return nd.mean(input, axis=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
return nd.max(input, axis=dim)
def reduce_max(input):
return input.max()
def min(input, dim):
return nd.min(input, axis=dim)
def reduce_min(input):
return input.min()
def topk(input, k, dim, descending=True):
return nd.topk(input, axis=dim, k=k, ret_typ='value', is_ascend=not descending)
def argsort(input, dim, descending):
idx = nd.argsort(input, dim, is_ascend=not descending)
idx = nd.cast(idx, dtype='int64')
return idx
def exp(input):
return nd.exp(input)
def softmax(input, dim=-1):
return nd.softmax(input, axis=dim)
def cat(seq, dim):
return nd.concat(*seq, dim=dim)
@@ -143,6 +176,9 @@ def split(x, sizes_or_sections, dim):
else:
return nd.split(x, sizes_or_sections, axis=dim)
def repeat(input, repeats, dim):
return nd.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
# MXNet workaround for empty row index
if len(row_index) == 0:
@@ -153,6 +189,17 @@ def gather_row(data, row_index):
else:
return data[row_index,]
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
return nd.slice_axis(data, axis, begin, end)
def take(data, indices, dim):
return nd.take(data, indices, dim)
def narrow_row(data, start, stop):
return data[start:stop]
@@ -181,6 +228,35 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype, ctx=ctx)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, nd.NDArray):
max_len = as_scalar(lengths.max())
else:
max_len = builtins.max(lengths)
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
ctx = input.context
dtype = input.dtype
x = nd.full((batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return scatter_row(x, index, input).reshape(batch_size, max_len, *old_shape[1:])
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
ctx = input.context
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
# TODO: support other dimensions
assert dim == 0, 'MXNet only supports segment sum on first dimension'
@@ -22,6 +22,11 @@ def cpu():
def tensor(data, dtype=None):
return np.array(data, dtype)
def as_scalar(data):
if data.ndim > 1:
raise ValueError('The data must have shape (1,).')
return data[0]
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -75,9 +80,46 @@ def copy_to(input, ctx):
def sum(input, dim):
return np.sum(input, axis=dim)
def reduce_sum(input):
dtype = input.dtype
return np.array(input.sum(), dtype=dtype)
def mean(input, dim):
return np.mean(input, axis=dim)
def reduce_mean(input):
dtype = input.dtype
return np.array(input.mean(), dtype=dtype)
def max(input, dim):
return np.max(input, axis=dim)
def reduce_max(input):
dtype = input.dtype
return np.array(input.max(), dtype=dtype)
def min(input, dim):
return np.min(input, axis=dim)
def reduce_min(input):
dtype = input.dtype
return np.array(input.min(), dtype=dtype)
def argsort(input, dim, descending):
if descending:
return np.argsort(-input, axis=dim)
return np.argsort(input, axis=dim)
def exp(input):
return np.exp(input)
def softmax(input, dim=-1):
# subtract the max along dim before exponentiating, for numerical stability
max_val = input.max(axis=dim)
minus_max = input - np.expand_dims(max_val, axis=dim)
exp_val = np.exp(minus_max)
sum_val = np.sum(exp_val, axis=dim)
return exp_val / np.expand_dims(sum_val, axis=dim)
def cat(seq, dim):
return np.concatenate(seq, axis=dim)
@@ -92,9 +134,20 @@ def split(input, sizes_or_sections, dim):
idx = np.cumsum(sizes_or_sections)[0:-1]
return np.split(input, idx, axis=dim)
def repeat(input, repeats, dim):
return np.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
return data[row_index]
def slice_axis(data, axis, begin, end):
if begin >= end:
raise IndexError("Begin index ({}) is greater than or equal to end index ({})".format(begin, end))
return np.take(data, np.arange(begin, end), axis=axis)
def take(data, indices, dim):
return np.take(data, indices, axis=dim)
def scatter_row(data, row_index, value):
# NOTE: inplace instead of out-place
data[row_index] = value
@@ -3,6 +3,7 @@ from __future__ import absolute_import
from distutils.version import LooseVersion
import torch as th
import builtins
from torch.utils import dlpack
from ... import ndarray as nd
@@ -26,6 +27,9 @@ def cpu():
def tensor(data, dtype=None):
return th.tensor(data, dtype=dtype)
def as_scalar(data):
return data.item()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -90,13 +94,41 @@ def copy_to(input, ctx):
def sum(input, dim):
return th.sum(input, dim=dim)
def reduce_sum(input):
return input.sum()
def mean(input, dim):
return th.mean(input, dim=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
# NOTE: the second argmax array is not returned
return th.max(input, dim=dim)[0]
def reduce_max(input):
return input.max()
def min(input, dim):
# NOTE: the second argmin array is not returned
return th.min(input, dim=dim)[0]
def reduce_min(input):
return input.min()
def argsort(input, dim, descending):
return th.argsort(input, dim=dim, descending=descending)
def topk(input, k, dim, descending=True):
return th.topk(input, k, dim, largest=descending)[0]
def exp(input):
return th.exp(input)
def softmax(input, dim=-1):
return th.softmax(input, dim=dim)
def cat(seq, dim):
return th.cat(seq, dim=dim)
@@ -106,9 +138,29 @@ def stack(seq, dim):
def split(input, sizes_or_sections, dim):
return th.split(input, sizes_or_sections, dim)
def repeat(input, repeats, dim):
# return th.repeat_interleave(input, repeats, dim) # PyTorch 1.1
if dim < 0:
dim += input.dim()
return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)
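The stack-and-flatten trick above reproduces repeat_interleave, which only became available in PyTorch 1.1 (hence the commented-out call); a small hedged sanity check of the equivalence, with illustrative values:

import torch as th

x = th.tensor([[1, 2], [3, 4]])
# repeat(x, 3, dim=0): stack 3 copies on a new axis right after dim, then flatten it away.
manual = th.flatten(th.stack([x] * 3, dim=1), 0, 1)
expected = th.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]])
assert th.equal(manual, expected)
# On PyTorch >= 1.1 this matches th.repeat_interleave(x, 3, dim=0).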
def gather_row(data, row_index):
return th.index_select(data, 0, row_index)
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
if begin >= end:
raise IndexError("Begin index ({}) is greater than or equal to end index ({})".format(begin, end))
return th.index_select(data, axis, th.arange(begin, end, device=data.device))
def take(data, indices, dim):
new_shape = data.shape[:dim] + indices.shape + data.shape[dim+1:]
return th.index_select(data, dim, indices.view(-1)).view(new_shape)
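A small hedged example of what this take helper computes (it mirrors numpy.take along a chosen dim); the values are illustrative:

import torch as th

data = th.tensor([[10, 11, 12], [20, 21, 22]])
indices = th.tensor([[0, 2], [1, 0]])
# take is the backend helper defined just above; result has shape (2, 2, 2)
out = take(data, indices, 1)
# out[i, j, k] == data[i, indices[j, k]]
assert th.equal(out[0], th.tensor([[10, 12], [11, 10]]))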
def narrow_row(x, start, stop):
return x[start:stop]
@@ -136,6 +188,35 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, th.Tensor):
max_len = as_scalar(lengths.max())
else:
max_len = builtins.max(lengths)
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
device = input.device
x = input.new(batch_size * max_len, *old_shape[1:])
x.fill_(value)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
device = input.device
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return gather_row(input.view(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input)
seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
"""Package for mxnet-specific NN modules.""" """Package for mxnet-specific NN modules."""
from .conv import * from .conv import *
from .glob import *
from .softmax import *
"""MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set']
class SumPooling(nn.Block):
r"""Apply sum pooling over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(SumPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute sum pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'SumPooling()'
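A minimal usage sketch (mirroring test_simple_pool added near the end of this commit); building the graph from networkx is only a convenience of the example:

import networkx as nx
import mxnet as mx
import dgl
import dgl.nn.mxnet as nn

g = dgl.DGLGraph(nx.path_graph(4))
feat = mx.nd.ones((g.number_of_nodes(), 5))
pool = nn.SumPooling()
out = pool(feat, g)   # one readout vector of shape (5,); (B, 5) for a BatchedDGLGraph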
class AvgPooling(nn.Block):
r"""Apply average pooling over the nodes in the graph.
.. math::
r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(AvgPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute average pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'AvgPooling()'
class MaxPooling(nn.Block):
r"""Apply max pooling over the nodes in the graph.
.. math::
r^{(i)} = \max_{k=1}^{N_i} \left( x^{(i)}_k \right)
"""
def __init__(self):
super(MaxPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute max pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'MaxPooling()'
class SortPooling(nn.Block):
r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
<https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.
Parameters
----------
k : int
The number of nodes to hold for each graph.
"""
def __init__(self, k):
super(SortPooling, self).__init__()
self.k = k
def forward(self, feat, graph):
r"""Compute sort pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`).
"""
# Sort the feature of each node in ascending order.
with graph.local_scope():
feat = feat.sort(axis=-1)
graph.ndata['h'] = feat
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k)[0].reshape(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(axis=0)
def __repr__(self):
return 'SortPooling(k={})'.format(self.k)
class GlobalAttentionPooling(nn.Block):
r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
<https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
\left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)
Parameters
----------
gate_nn : gluon.nn.Block
A neural network that computes attention scores for each feature.
feat_nn : gluon.nn.Block, optional
A neural network applied to each feature before combining them
with attention scores.
"""
def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__()
with self.name_scope():
self.gate_nn = gate_nn
self.feat_nn = feat_nn
self._reset_parameters()
def _reset_parameters(self):
self.gate_nn.initialize(mx.init.Xavier())
if self.feat_nn:
self.feat_nn.initialize(mx.init.Xavier())
def forward(self, feat, graph):
r"""Compute global attention pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`).
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
return readout
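For reference, a hedged construction sketch (following test_glob_att_pool further down): the gate network must map each node feature to a single score, which is exactly what the assert above enforces; the Dense sizes are illustrative.

from mxnet import gluon, nd
import networkx as nx
import dgl
import dgl.nn.mxnet as nn

gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
g = dgl.DGLGraph(nx.path_graph(6))
feat = nd.random.randn(g.number_of_nodes(), 5)
out = gap(feat, g)   # shape (10,), because feat_nn projects node features to 10 dims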
class Set2Set(nn.Block):
r"""Apply Set2Set (`Order Matters: Sequence to sequence for sets
<https://arxiv.org/pdf/1511.06391.pdf>`__) over the nodes in the graph.
For each individual graph in the batch, set2set computes
.. math::
q_t &= \mathrm{LSTM} (q^*_{t-1})
\alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t)
r_t &= \sum_{i=1}^N \alpha_{i,t} x_i
q^*_t &= q_t \Vert r_t
for this graph.
Parameters
----------
input_dim : int
Size of each input sample
n_iters : int
Number of iterations.
n_layers : int
Number of recurrent layers.
"""
def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__()
self.input_dim = input_dim
self.output_dim = 2 * input_dim
self.n_iters = n_iters
self.n_layers = n_layers
with self.name_scope():
self.lstm = gluon.rnn.LSTM(
self.input_dim, num_layers=n_layers, input_size=self.output_dim)
self._reset_parameters()
def _reset_parameters(self):
self.lstm.initialize(mx.init.Xavier())
def forward(self, feat, graph):
r"""Compute set2set pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(2D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, 2D)`).
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
q_star = nd.zeros((batch_size, self.output_dim), ctx=feat.context)
for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.ndim == 1: # graph is not a BatchedDGLGraph
readout = readout.expand_dims(0)
q_star = nd.concat(q, readout, dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(axis=0)
def __repr__(self):
summary = 'Set2Set('
summary += 'in={}, out={}, ' \
'n_iters={}, n_layers={}'.format(self.input_dim,
self.output_dim,
self.n_iters,
self.n_layers)
summary += ')'
return summary
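A short hedged usage sketch for Set2Set (mirroring test_set2set below); the readout width is 2 * input_dim because q_t and r_t are concatenated at every iteration.

from mxnet import nd
import networkx as nx
import dgl
import dgl.nn.mxnet as nn

g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(input_dim=5, n_iters=3, n_layers=1)
feat = nd.random.randn(g.number_of_nodes(), 5)
out = s2s(feat, g)   # shape (10,) == (2 * input_dim,); (B, 10) for a BatchedDGLGraph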
"""Package for pytorch-specific NN modules.""" """Package for pytorch-specific NN modules."""
from .conv import * from .conv import *
from .glob import *
from .softmax import *
import dgl
import backend as F
import networkx as nx
def test_simple_readout():
g1 = dgl.DGLGraph()
@@ -57,8 +58,192 @@ def test_simple_readout():
max_bg_e = dgl.max_edges(g, 'x')
assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
# TODO(zihao): fix -inf issue
# assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
def test_topk_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(14))
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
# to test the case where k > number of nodes.
dgl.topk_nodes(g0, 'x', 20, idx=-1)
# test correctness
val, indices = dgl.topk_nodes(g0, 'x', 5, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:5], (5,))
assert F.allclose(ground_truth, indices)
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_nodes(), 10))
bg = dgl.batch([g0, g1])
bg.ndata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of nodes.
dgl.topk_nodes(bg, 'x', 16, idx=1)
# test correctness
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:6], (6,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:6], (6,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_topk_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(14))
feat0 = F.randn((g0.number_of_edges(), 10))
g0.edata['x'] = feat0
# to test the case where k > number of edges.
dgl.topk_edges(g0, 'x', 30, idx=-1)
# test correctness
val, indices = dgl.topk_edges(g0, 'x', 7, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:7], (7,))
assert F.allclose(ground_truth, indices)
g0.edata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_edges(), 10))
bg = dgl.batch([g0, g1])
bg.edata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of edges.
dgl.topk_edges(bg, 'x', 33, idx=1)
# test correctness
val, indices = dgl.topk_edges(bg, 'x', 4, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:4], (4,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:4], (4,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
val, indices = dgl.topk_edges(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_softmax_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(9))
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
ground_truth = F.softmax(feat0, dim=0)
assert F.allclose(dgl.softmax_nodes(g0, 'x'), ground_truth)
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(3))
g3 = dgl.DGLGraph()
g4 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g0, g1, g2, g3, g4])
feat1 = F.randn((g1.number_of_nodes(), 10))
feat2 = F.randn((g2.number_of_nodes(), 10))
feat4 = F.randn((g4.number_of_nodes(), 10))
bg.ndata['x'] = F.cat([feat0, feat1, feat2, feat4], 0)
ground_truth = F.cat([
F.softmax(feat0, 0),
F.softmax(feat1, 0),
F.softmax(feat2, 0),
F.softmax(feat4, 0)
], 0)
assert F.allclose(dgl.softmax_nodes(bg, 'x'), ground_truth)
def test_softmax_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((g0.number_of_edges(), 10))
g0.edata['x'] = feat0
ground_truth = F.softmax(feat0, dim=0)
assert F.allclose(dgl.softmax_edges(g0, 'x'), ground_truth)
g0.edata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(3))
g3 = dgl.DGLGraph()
g4 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g0, g1, g2, g3, g4])
feat1 = F.randn((g1.number_of_edges(), 10))
feat2 = F.randn((g2.number_of_edges(), 10))
feat4 = F.randn((g4.number_of_edges(), 10))
bg.edata['x'] = F.cat([feat0, feat1, feat2, feat4], 0)
ground_truth = F.cat([
F.softmax(feat0, 0),
F.softmax(feat1, 0),
F.softmax(feat2, 0),
F.softmax(feat4, 0)
], 0)
assert F.allclose(dgl.softmax_edges(bg, 'x'), ground_truth)
def test_broadcast_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
[feat0] * g0.number_of_nodes() +\
[feat1] * g1.number_of_nodes() +\
[feat2] * g2.number_of_nodes() +\
[feat3] * g3.number_of_nodes(), 0
)
assert F.allclose(dgl.broadcast_nodes(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
), ground_truth)
def test_broadcast_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
[feat0] * g0.number_of_edges() +\
[feat1] * g1.number_of_edges() +\
[feat2] * g2.number_of_edges() +\
[feat3] * g3.number_of_edges(), 0
)
assert F.allclose(dgl.broadcast_edges(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
), ground_truth)
if __name__ == '__main__':
test_simple_readout()
test_topk_nodes()
test_topk_edges()
test_softmax_nodes()
test_softmax_edges()
test_broadcast_nodes()
test_broadcast_edges()
@@ -3,11 +3,10 @@ import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
from mxnet import autograd, gluon
def check_close(a, b):
assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)
def _AXWb(A, X, W, b):
X = mx.nd.dot(X, W.data(X.context))
@@ -26,13 +25,13 @@ def test_graph_conv():
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = mx.nd.ones((3, 5, 5))
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
conv.initialize(ctx=ctx)
@@ -69,7 +68,92 @@ def test_graph_conv():
assert len(g.ndata) == 1
assert len(g.edata) == 0
assert "h" in g.ndata
check_close(g.ndata['h'], 2 * mx.nd.ones((3, 1)))
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
print(s2s)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = s2s(h0, g)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = s2s(h0, bg)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
print(gap)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = gap(h0, g)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = gap(h0, bg)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_simple_pool():
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
avg_pool = nn.AvgPooling()
max_pool = nn.MaxPooling()
sort_pool = nn.SortPooling(10) # k = 10
print(sum_pool, avg_pool, max_pool, sort_pool)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = sum_pool(h0, g)
check_close(h1, mx.nd.sum(h0, 0))
h1 = avg_pool(h0, g)
check_close(h1, mx.nd.mean(h0, 0))
h1 = max_pool(h0, g)
check_close(h1, mx.nd.max(h0, 0))
h1 = sort_pool(h0, g)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = sum_pool(h0, bg)
truth = mx.nd.stack(mx.nd.sum(h0[:15], 0),
mx.nd.sum(h0[15:20], 0),
mx.nd.sum(h0[20:35], 0),
mx.nd.sum(h0[35:40], 0),
mx.nd.sum(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = avg_pool(h0, bg)
truth = mx.nd.stack(mx.nd.mean(h0[:15], 0),
mx.nd.mean(h0[15:20], 0),
mx.nd.mean(h0[20:35], 0),
mx.nd.mean(h0[35:40], 0),
mx.nd.mean(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = max_pool(h0, bg)
truth = mx.nd.stack(mx.nd.max(h0[:15], 0),
mx.nd.max(h0[15:20], 0),
mx.nd.max(h0[20:35], 0),
mx.nd.max(h0[35:40], 0),
mx.nd.max(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = sort_pool(h0, bg)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def uniform_attention(g, shape):
a = mx.nd.ones(shape)
@@ -97,3 +181,6 @@ def test_edge_softmax():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
@@ -61,6 +61,124 @@ def test_graph_conv():
new_weight = conv.weight.data
assert not th.allclose(old_weight, new_weight)
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
print(s2s)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = s2s(h0, g)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(11))
g2 = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g1, g2])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = s2s(h0, bg)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
print(gap)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = gap(h0, g)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = gap(h0, bg)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
def test_simple_pool():
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
avg_pool = nn.AvgPooling()
max_pool = nn.MaxPooling()
sort_pool = nn.SortPooling(10) # k = 10
print(sum_pool, avg_pool, max_pool, sort_pool)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = sum_pool(h0, g)
assert th.allclose(h1, th.sum(h0, 0))
h1 = avg_pool(h0, g)
assert th.allclose(h1, th.mean(h0, 0))
h1 = max_pool(h0, g)
assert th.allclose(h1, th.max(h0, 0)[0])
h1 = sort_pool(h0, g)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = sum_pool(h0, bg)
truth = th.stack([th.sum(h0[:15], 0),
th.sum(h0[15:20], 0),
th.sum(h0[20:35], 0),
th.sum(h0[35:40], 0),
th.sum(h0[40:55], 0)], 0)
assert th.allclose(h1, truth)
h1 = avg_pool(h0, bg)
truth = th.stack([th.mean(h0[:15], 0),
th.mean(h0[15:20], 0),
th.mean(h0[20:35], 0),
th.mean(h0[35:40], 0),
th.mean(h0[40:55], 0)], 0)
assert th.allclose(h1, truth)
h1 = max_pool(h0, bg)
truth = th.stack([th.max(h0[:15], 0)[0],
th.max(h0[15:20], 0)[0],
th.max(h0[20:35], 0)[0],
th.max(h0[35:40], 0)[0],
th.max(h0[40:55], 0)[0]], 0)
assert th.allclose(h1, truth)
h1 = sort_pool(h0, bg)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
g = dgl.DGLGraph(nx.path_graph(15))
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
print(st_enc_0, st_enc_1, st_dec)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 50)
h1 = st_enc_0(h0, g)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, g)
assert h1.shape == h0.shape
h2 = st_dec(h1, g)
assert h2.shape[0] == 200 and h2.dim() == 1
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g1, g2])
h0 = th.rand(bg.number_of_nodes(), 50)
h1 = st_enc_0(h0, bg)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, bg)
assert h1.shape == h0.shape
h2 = st_dec(h1, bg)
assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2
def uniform_attention(g, shape):
a = th.ones(shape)
target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
@@ -130,3 +248,7 @@ def test_edge_softmax():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
test_set_trans()