Unverified Commit 5d3f470b authored by Zihao Ye, committed by GitHub

[Feature] DGL Pooling modules (#669)

* removal doc

* glob

* upd

* rm knn

* add softmax

* upd

* upd

* add broadcast and s2s

* optimize max_on

* forsaken changes to heterograph

* upd

* upd

* upd

* upd

* upd

* bugfix

* upd

* upd

* upd

* upd

* format upd

* upd format

* upd doc

* upd

* import order

* upd

* rm warnings

* fix

* upd test

* upd

* upd

* fix device

* upd

* upd

* upd

* upd

* remove 1.1

* upd

* trigger

* trigger

* add more tests

* fix device

* upd

* upd

* refactor

* fix?

* fix

* upd docstring

* refactor

* upd

* fix

* upd

* upd

* upd

* fix

* upd docs

* add shape

* refactor & upd doc

* upd doc

* upd
parent c3516f1a
@@ -37,3 +37,9 @@ Graph Readout
mean_edges
max_nodes
max_edges
topk_nodes
topk_edges
softmax_nodes
softmax_edges
broadcast_nodes
broadcast_edges
@@ -11,3 +11,18 @@ dgl.nn.mxnet.conv
.. autoclass:: dgl.nn.mxnet.conv.GraphConv
:members: weight, bias, forward
:show-inheritance:
dgl.nn.mxnet.glob
-----------------
.. automodule:: dgl.nn.mxnet.glob
:members:
dgl.nn.mxnet.softmax
--------------------
.. automodule:: dgl.nn.mxnet.softmax
.. autoclass:: dgl.nn.mxnet.softmax.EdgeSoftmax
:members: forward
:show-inheritance:
@@ -12,6 +12,43 @@ dgl.nn.pytorch.conv
:members: weight, bias, forward, reset_parameters
:show-inheritance:
dgl.nn.pytorch.glob
-------------------
.. automodule:: dgl.nn.pytorch.glob
.. autoclass:: dgl.nn.pytorch.glob.SumPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.AvgPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.MaxPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SortPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.GlobalAttentionPooling
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.Set2Set
:members: forward
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SetTransformerEncoder
:members:
:show-inheritance:
.. autoclass:: dgl.nn.pytorch.glob.SetTransformerDecoder
:members:
:show-inheritance:
dgl.nn.pytorch.softmax
----------------------
...
@@ -74,6 +74,21 @@ def tensor(data, dtype=None):
"""
pass
def as_scalar(data):
"""Returns a scalar whose value is copied from this array.
Parameters
----------
data : Tensor
The input data
Returns
-------
scalar
The scalar value in the tensor.
"""
pass
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -293,6 +308,21 @@ def sum(input, dim):
"""
pass
def reduce_sum(input):
"""Returns the sum of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def mean(input, dim):
"""Reduce average the input tensor along the given dim.
@@ -310,6 +340,21 @@ def mean(input, dim):
"""
pass
def reduce_mean(input):
"""Returns the average of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def max(input, dim):
"""Reduce max the input tensor along the given dim.
@@ -327,6 +372,121 @@ def max(input, dim):
"""
pass
def reduce_max(input):
"""Returns the max of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def min(input, dim):
"""Reduce min the input tensor along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The reduce dim.
Returns
-------
Tensor
A framework-specific tensor.
"""
pass
def reduce_min(input):
"""Returns the min of all elements in the input tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
A framework-specific tensor with shape (1,)
"""
pass
def argsort(input, dim, descending):
"""Return the indices that would sort the input along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dim to sort along.
descending : bool
Controls the sorting order (False: ascending, True: descending)
Returns
-------
Tensor
A framework-specific tensor.
"""
def topk(input, k, dim, descending=True):
"""Return the k largest elements of the given input tensor along the given dimension.
If descending is False then the k smallest elements are returned.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dim to sort along.
descending : bool
Controls whether to return largest/smallest elements.
"""
pass
def exp(input):
"""Returns a new tensor with the exponential of the elements of the input tensor `input`.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
Tensor
The output tensor.
"""
pass
def softmax(input, dim=-1):
"""Apply the softmax function on given dimension.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The dimension along which to compute softmax.
Returns
-------
Tensor
The output tensor.
"""
pass
def cat(seq, dim):
"""Concat the sequence of tensors in the given dimension.
@@ -381,6 +541,25 @@ def split(input, sizes_or_sections, dim):
"""
pass
def repeat(input, repeats, dim):
"""Repeats elements of an array.
Parameters
----------
input : Tensor
Input data array
repeats : int
The number of repetitions for each element
dim : int
The dim along which to repeat values.
Returns
-------
Tensor
The obtained tensor.
"""
pass
def gather_row(data, row_index):
"""Slice out the data given the row index.
@@ -398,6 +577,41 @@ def gather_row(data, row_index):
"""
pass
def slice_axis(data, axis, begin, end):
"""Slice along a given axis.
Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index.
Parameters
----------
data : Tensor
The data tensor.
axis : int
The axis along which to slice the tensor.
begin : int
Indicates the begin index.
end : int
Indicates the end index.
Returns
-------
Tensor
The sliced tensor.
"""
pass
def take(data, indices, dim):
"""Takes elements from an input array along the given dim.
Parameters
----------
data : Tensor
The data tensor.
indices : Tensor
The indices tensor.
dim : int
The dimension to gather along.
"""
pass
def narrow_row(x, start, stop):
"""Narrow down the tensor along the first dimension.
@@ -563,6 +777,50 @@ def ones(shape, dtype, ctx):
"""
pass
def pad_packed_tensor(input, lengths, value, l_min=None):
"""Pads a packed batch of variable length tensors with given value.
Parameters
----------
input : Tensor
The input tensor with shape :math:`(N, *)`
lengths : list or tensor
The array of tensor lengths (of the first dimension) :math:`L`.
It should satisfy :math:`\sum_{i=1}^{B}L_i = N`,
where :math:`B` is the length of :math:`L`.
value : float
The value to fill in the tensor.
l_min : int or None, defaults to None.
The minimum length each tensor needs to be padded to; if set to None,
there is no minimum length requirement.
Returns
-------
Tensor
The obtained tensor with shape :math:`(B, \max(\max_i(L_i), l_{min}), *)`
"""
pass
def pack_padded_tensor(input, lengths):
"""Packs a tensor containing padded sequence of variable length.
Parameters
----------
input : Tensor
The input tensor with shape :math:`(B, L, *)`, where :math:`B` is
the batch size and :math:`L` is the maximum length of the batch.
lengths : list or tensor
The array of tensor lengths (of the first dimension) :math:`L`.
:math:`\max_i(L_i)` should equal :math:`L`.
Returns
-------
Tensor
The obtained tensor with shape :math:`(N, *)` where
:math:`N = \sum_{i=1}^{B}L_i`
"""
pass
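# Illustrative sketch, not from this commit: the pad/pack round trip described above,
# written with plain PyTorch only; `pad_packed_demo` / `pack_padded_demo` are made-up names.
import torch as th

def pad_packed_demo(x, lengths, value):
    # x is packed with shape (N, *), where N == sum(lengths); pad each example to max(lengths)
    max_len = max(lengths)
    out = x.new_full((len(lengths), max_len) + tuple(x.shape[1:]), value)
    offset = 0
    for i, l in enumerate(lengths):
        out[i, :l] = x[offset:offset + l]
        offset += l
    return out  # shape (B, max_len, *)

def pack_padded_demo(x, lengths):
    # inverse: keep only the first lengths[i] rows of each example and re-concatenate
    return th.cat([x[i, :l] for i, l in enumerate(lengths)], dim=0)

feat = th.arange(5, dtype=th.float32).view(5, 1)   # packed batch with lengths [2, 3]
padded = pad_packed_demo(feat, [2, 3], value=0.0)  # shape (2, 3, 1)
assert th.equal(pack_padded_demo(padded, [2, 3]), feat)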
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
"""Computes the sum along segments of a tensor.
...
@@ -6,6 +6,7 @@ import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import numbers
import builtins
from ... import ndarray as dglnd
from ... import kernel as K
@@ -38,6 +39,9 @@ def tensor(data, dtype=None):
dtype = np.float32
return nd.array(data, dtype=dtype)
def as_scalar(data):
return data.asscalar()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -112,12 +116,41 @@ def copy_to(input, ctx):
def sum(input, dim):
return nd.sum(input, axis=dim)
def reduce_sum(input):
return input.sum()
def mean(input, dim):
return nd.mean(input, axis=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
return nd.max(input, axis=dim)
def reduce_max(input):
return input.max()
def min(input, dim):
return nd.min(input, axis=dim)
def reduce_min(input):
return input.min()
def topk(input, k, dim, descending=True):
return nd.topk(input, axis=dim, k=k, ret_typ='value', is_ascend=not descending)
def argsort(input, dim, descending):
idx = nd.argsort(input, dim, is_ascend=not descending)
idx = nd.cast(idx, dtype='int64')
return idx
def exp(input):
return nd.exp(input)
def softmax(input, dim=-1):
return nd.softmax(input, axis=dim)
def cat(seq, dim):
return nd.concat(*seq, dim=dim)
@@ -143,6 +176,9 @@ def split(x, sizes_or_sections, dim):
else:
return nd.split(x, sizes_or_sections, axis=dim)
def repeat(input, repeats, dim):
return nd.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
# MXNet workaround for empty row index
if len(row_index) == 0:
@@ -153,6 +189,17 @@ def gather_row(data, row_index):
else:
return data[row_index,]
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
return nd.slice_axis(data, axis, begin, end)
def take(data, indices, dim):
return nd.take(data, indices, dim)
def narrow_row(data, start, stop):
return data[start:stop]
@@ -181,6 +228,35 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype, ctx=ctx)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, nd.NDArray):
max_len = as_scalar(lengths.max())
else:
max_len = builtins.max(lengths)
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
ctx = input.context
dtype = input.dtype
x = nd.full((batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return scatter_row(x, index, input).reshape(batch_size, max_len, *old_shape[1:])
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
ctx = input.context
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
# TODO: support other dimensions
assert dim == 0, 'MXNet only supports segment sum on first dimension'
...
@@ -22,6 +22,11 @@ def cpu():
def tensor(data, dtype=None):
return np.array(data, dtype)
def as_scalar(data):
if data.ndim > 1:
raise ValueError('The data must have shape (1,).')
return data[0]
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -75,9 +80,46 @@ def copy_to(input, ctx):
def sum(input, dim):
return np.sum(input, axis=dim)
def reduce_sum(input):
dtype = input.dtype
return np.array(input.sum(), dtype=dtype)
def mean(input, dim):
return np.mean(input, axis=dim)
def reduce_mean(input):
dtype = input.dtype
return np.array(input.mean(), dtype=dtype)
def max(input, dim):
return np.max(input, axis=dim)
def reduce_max(input):
dtype = input.dtype
return np.array(input.max(), dtype=dtype)
def min(input, dim):
return np.min(input, axis=dim)
def reduce_min(input):
dtype = input.dtype
return np.array(input.min(), dtype=dtype)
def argsort(input, dim, descending):
if descending:
return np.argsort(-input, axis=dim)
return np.argsort(input, axis=dim)
def exp(input):
return np.exp(input)
def softmax(input, dim=-1):
max_val = input.max(axis=dim)
minus_max = input - np.expand_dims(max_val, axis=dim)
exp_val = np.exp(minus_max)
sum_val = np.sum(exp_val, axis=dim)
return exp_val / np.expand_dims(sum_val, axis=dim)
def cat(seq, dim):
return np.concatenate(seq, axis=dim)
@@ -92,9 +134,20 @@ def split(input, sizes_or_sections, dim):
idx = np.cumsum(sizes_or_sections)[0:-1]
return np.split(input, idx, axis=dim)
def repeat(input, repeats, dim):
return np.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
return data[row_index]
def slice_axis(data, axis, begin, end):
if begin >= end:
raise IndexError("Begin index ({}) is greater than or equal to end index ({})".format(begin, end))
return np.take(data, np.arange(begin, end), axis=axis)
def take(data, indices, dim):
return np.take(data, indices, axis=dim)
def scatter_row(data, row_index, value):
# NOTE: inplace instead of out-place
data[row_index] = value
...
@@ -3,6 +3,7 @@ from __future__ import absolute_import
from distutils.version import LooseVersion
import torch as th
import builtins
from torch.utils import dlpack
from ... import ndarray as nd
@@ -26,6 +27,9 @@ def cpu():
def tensor(data, dtype=None):
return th.tensor(data, dtype=dtype)
def as_scalar(data):
return data.item()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -90,13 +94,41 @@ def copy_to(input, ctx):
def sum(input, dim):
return th.sum(input, dim=dim)
def reduce_sum(input):
return input.sum()
def mean(input, dim):
return th.mean(input, dim=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
# NOTE: the second argmax array is not returned
return th.max(input, dim=dim)[0]
def reduce_max(input):
return input.max()
def min(input, dim):
# NOTE: the second argmin array is not returned
return th.min(input, dim=dim)[0]
def reduce_min(input):
return input.min()
def argsort(input, dim, descending):
return th.argsort(input, dim=dim, descending=descending)
def topk(input, k, dim, descending=True):
return th.topk(input, k, dim, largest=descending)[0]
def exp(input):
return th.exp(input)
def softmax(input, dim=-1):
return th.softmax(input, dim=dim)
def cat(seq, dim):
return th.cat(seq, dim=dim)
@@ -106,9 +138,29 @@ def stack(seq, dim):
def split(input, sizes_or_sections, dim):
return th.split(input, sizes_or_sections, dim)
def repeat(input, repeats, dim):
# return th.repeat_interleave(input, repeats, dim) # PyTorch 1.1
if dim < 0:
dim += input.dim()
return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)
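# Illustrative sketch, not from this commit: a quick equivalence check for the workaround
# above; it assumes a PyTorch release that already ships `repeat_interleave` (1.1+).
import torch as th

x = th.arange(6).view(2, 3)
dim, repeats = 1, 2
stacked = th.flatten(th.stack([x] * repeats, dim=dim + 1), dim, dim + 1)
assert th.equal(stacked, th.repeat_interleave(x, repeats, dim))  # both repeat each element along `dim`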
def gather_row(data, row_index):
return th.index_select(data, 0, row_index)
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
if begin >= end:
raise IndexError("Begin index ({}) is greater than or equal to end index ({})".format(begin, end))
return th.index_select(data, axis, th.arange(begin, end, device=data.device))
def take(data, indices, dim):
new_shape = data.shape[:dim] + indices.shape + data.shape[dim+1:]
return th.index_select(data, dim, indices.view(-1)).view(new_shape)
def narrow_row(x, start, stop):
return x[start:stop]
@@ -136,6 +188,35 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, th.Tensor):
max_len = as_scalar(lengths.max())
else:
max_len = builtins.max(lengths)
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
device = input.device
x = input.new(batch_size * max_len, *old_shape[1:])
x.fill_(value)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
device = input.device
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return gather_row(input.view(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input)
seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
...
@@ -13,7 +13,8 @@ from . import utils
__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
'max_nodes', 'max_edges', 'softmax_nodes', 'softmax_edges',
'broadcast_nodes', 'broadcast_edges', 'topk_nodes', 'topk_edges']
class BatchedDGLGraph(DGLGraph):
"""Class for batched DGL graphs.
@@ -375,7 +376,7 @@ def _sum_on(graph, typestr, feat, weight):
Returns
-------
tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
@@ -390,7 +391,6 @@ def _sum_on(graph, typestr, feat, weight):
if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
@@ -555,7 +555,7 @@ def _mean_on(graph, typestr, feat, weight):
Returns
-------
tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
@@ -570,7 +570,6 @@ def _mean_on(graph, typestr, feat, weight):
if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
if weight is not None:
@@ -729,7 +728,7 @@ def mean_edges(graph, feat, weight=None):
def _max_on(graph, typestr, feat):
"""Internal function to take elementwise maximum
over node or edge features.
Parameters
----------
@@ -742,29 +741,186 @@ def _max_on(graph, typestr, feat):
Returns
-------
tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
feat = data[feat]
# TODO: the current solution pads the different graph sizes to the same,
# a more efficient way is to use segment max, we need to implement it in
# the future.
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
max_readout_list = []
first = 0
for num_obj in batch_num_objs:
if num_obj == 0:
max_readout_list.append(F.zeros(F.shape(feat)[1:],
F.dtype(feat),
F.context(feat)))
continue
max_readout_list.append(F.max(feat[first:first+num_obj], 0))
first += num_obj
return F.stack(max_readout_list, 0)
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
return F.max(feat, 1)
else:
return F.max(feat, 0)
def _softmax_on(graph, typestr, feat):
"""Internal function of applying batch-wise graph-level softmax
over node or edge features of a given field.
Parameters
----------
graph : DGLGraph
The graph
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
Returns
-------
tensor
The obtained tensor.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
feat = data[feat]
# TODO: the current solution pads the different graph sizes to the same,
# a more efficient way is to use segment sum/max, we need to implement
# it in the future.
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
feat = F.softmax(feat, 1)
return F.pack_padded_tensor(feat, batch_num_objs)
else:
return F.softmax(feat, 0)
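# Illustrative sketch, not from this commit: the pad -> softmax -> pack pipeline above
# on a toy batch, written with plain PyTorch; values mirror the softmax_nodes docstring.
import torch as th

feat = th.tensor([[1.], [2.], [1.], [2.], [3.]])   # packed node features, graph sizes [2, 3]
lengths = [2, 3]
max_len = max(lengths)
padded = feat.new_full((len(lengths), max_len, 1), -float('inf'))
offset = 0
for i, l in enumerate(lengths):
    padded[i, :l] = feat[offset:offset + l]
    offset += l
soft = th.softmax(padded, dim=1)                   # -inf padding slots receive zero weight
packed = th.cat([soft[i, :l] for i, l in enumerate(lengths)], dim=0)
# packed[:2] equals softmax([1., 2.]) and packed[2:] equals softmax([1., 2., 3.])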
def _broadcast_on(graph, typestr, feat_data):
"""Internal function of broadcasting features to all nodes/edges.
Parameters
----------
graph : DGLGraph
The graph
typestr : str
'nodes' or 'edges'
feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single graph,
and :math:`(B, *)` for batched graph.
Returns
-------
tensor
The node/edge features tensor with shape :math:`(N, *)`.
"""
_, batch_num_objs_attr, num_objs_attr = READOUT_ON_ATTRS[typestr]
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
index = []
for i, num_obj in enumerate(batch_num_objs):
index.extend([i] * num_obj)
ctx = F.context(feat_data)
index = F.copy_to(F.tensor(index), ctx)
return F.gather_row(feat_data, index)
else:
num_objs = getattr(graph, num_objs_attr)()
if F.ndim(feat_data) == 1:
feat_data = F.unsqueeze(feat_data, 0)
return F.cat([feat_data] * num_objs, 0)
def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
"""Internal function to take graph-wise top-k node/edge features of
field :attr:`feat` in :attr:`graph` ranked by keys at given
index :attr:`idx`. If :attr:`descending` is set to False, return the
k smallest elements instead.
If idx is set to None, the function would return top-k value of all
indices, which is equivalent to calling `th.topk(graph.ndata[feat], dim=0)`
for each example of the input graph.
Parameters
----------
graph : DGLGraph
The graph
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
k : int
The :math:`k` in "top-:math:`k`".
descending : bool
Controls whether to return the largest or smallest elements,
defaults to True.
idx : int or None, defaults to None
The key index we sort :attr:`feat` on, if set to None, we sort
the whole :attr:`feat`.
Returns
-------
tuple of tensors:
The first tensor returns top-k features of the given graph with
shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size.
The second tensor returns the top-k indices of the given graph
with shape :math:`(K)`, if the input graph is a BatchedDGLGraph,
a tensor with shape :math:`(B, K)` would be returned, where
:math:`B` is the batch size.
Notes
-----
If an example has :math:`n` nodes/edges and :math:`n<k`, in the first
returned tensor the :math:`n+1` to :math:`k`th rows would be padded
with all zero; in the second returned tensor, the behavior of :math:`n+1`
to :math:`k`th elements is not defined.
"""
data_attr, batch_num_objs_attr, num_objs_attr = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
if F.ndim(data[feat]) > 2:
raise DGLError('The {} feature `{}` should have dimension less than or'
' equal to 2'.format(typestr, feat))
feat = data[feat]
hidden_size = F.shape(feat)[-1]
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_size = len(batch_num_objs)
else:
batch_num_objs = [getattr(graph, num_objs_attr)()]
batch_size = 1
length = max(max(batch_num_objs), k)
fill_val = -float('inf') if descending else float('inf')
feat_ = F.pad_packed_tensor(feat, batch_num_objs, fill_val, l_min=k)
if idx is not None:
keys = F.squeeze(F.slice_axis(feat_, -1, idx, idx+1), -1)
order = F.argsort(keys, -1, descending=descending)
else:
order = F.argsort(feat_, 1, descending=descending)
topk_indices = F.slice_axis(order, 1, 0, k)
# zero padding
feat_ = F.pad_packed_tensor(feat, batch_num_objs, 0, l_min=k)
if idx is not None:
feat_ = F.reshape(feat_, (batch_size * length, -1))
shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
shift = F.copy_to(shift, F.context(feat))
topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
else:
feat_ = F.reshape(feat_, (-1,))
shift = F.repeat(F.arange(0, batch_size), k * hidden_size, -1) * length * hidden_size +\
F.cat([F.arange(0, hidden_size)] * batch_size * k, -1)
shift = F.copy_to(shift, F.context(feat))
topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
if isinstance(graph, BatchedDGLGraph):
return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\
topk_indices
else:
return F.reshape(F.gather_row(feat_, topk_indices_), (k, -1)),\
topk_indices
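# Illustrative sketch, not from this commit: the pad-then-sort idea behind the batched
# top-k above, written with plain PyTorch and ignoring the `idx` branch; values are toy.
import torch as th

feat = th.tensor([[1., 5.], [4., 2.], [3., 3.], [2., 9.], [6., 1.]])  # graph sizes [2, 3]
lengths, k = [2, 3], 2
max_len = max(max(lengths), k)
padded = feat.new_full((len(lengths), max_len, feat.shape[1]), -float('inf'))
offset = 0
for i, l in enumerate(lengths):
    padded[i, :l] = feat[offset:offset + l]
    offset += l
vals, order = padded.sort(dim=1, descending=True)  # sort every feature column independently
topk_vals, topk_idx = vals[:, :k], order[:, :k]    # shapes (B, k, D) and (B, k, D)
# The function above additionally re-gathers from a zero-padded copy, so graphs with
# fewer than k rows are padded with zeros in the returned values instead of -inf.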
def max_nodes(graph, feat):
"""Take elementwise maximum over all the values of node field
:attr:`feat` in :attr:`graph`
@@ -781,13 +937,43 @@ def max_nodes(graph, feat):
tensor
The tensor obtained.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
node features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(2)
>>> g1.ndata['h'] = th.tensor([[1.], [2.]])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(3)
>>> g2.ndata['h'] = th.tensor([[1.], [2.], [3.]])
Max over node attribute :attr:`h` in a batched graph.
>>> bg = dgl.batch([g1, g2], node_attrs='h')
>>> dgl.max_nodes(bg, 'h')
tensor([[2.], # max(1, 2)
[3.]]) # max(1, 2, 3)
Max over node attribute :attr:`h` in a single graph.
>>> dgl.max_nodes(g1, 'h')
tensor([2.])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no nodes,
a tensor filled with -inf of the same shape is returned at the
corresponding row.
"""
return _max_on(graph, 'nodes', feat)
@@ -807,12 +993,514 @@ def max_edges(graph, feat):
tensor
The tensor obtained.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
edge features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(2)
>>> g1.add_edges([0, 1], [1, 0])
>>> g1.edata['h'] = th.tensor([[1.], [2.]])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(3)
>>> g2.add_edges([0, 1, 2], [1, 2, 0])
>>> g2.edata['h'] = th.tensor([[1.], [2.], [3.]])
Max over edge attribute :attr:`h` in a batched graph.
>>> bg = dgl.batch([g1, g2], edge_attrs='h')
>>> dgl.max_edges(bg, 'h')
tensor([[2.], # max(1, 2)
[3.]]) # max(1, 2, 3)
Max over edge attribute :attr:`h` in a single graph.
>>> dgl.max_edges(g1, 'h')
tensor([2.])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no edges,
a tensor filled with -inf of the same shape is returned at the
corresponding row.
"""
return _max_on(graph, 'edges', feat)
def softmax_nodes(graph, feat):
"""Apply batch-wise graph-level softmax over all the values of node field
:attr:`feat` in :attr:`graph`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : str
The feature field.
Returns
-------
tensor
The tensor obtained.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
node features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(2)
>>> g1.ndata['h'] = th.tensor([[1., 0.], [2., 0.]])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(3)
>>> g2.ndata['h'] = th.tensor([[1., 0.], [2., 0.], [3., 0.]])
Softmax over node attribute :attr:`h` in a batched graph.
>>> bg = dgl.batch([g1, g2], node_attrs='h')
>>> dgl.softmax_nodes(bg, 'h')
tensor([[0.2689, 0.5000], # [0.2689, 0.7311] = softmax([1., 2.])
[0.7311, 0.5000], # [0.5000, 0.5000] = softmax([0., 0.])
[0.0900, 0.3333], # [0.0900, 0.2447, 0.6652] = softmax([1., 2., 3.])
[0.2447, 0.3333], # [0.3333, 0.3333, 0.3333] = softmax([0., 0., 0.])
[0.6652, 0.3333]])
Softmax over node attribute :attr:`h` in a single graph.
>>> dgl.softmax_nodes(g1, 'h')
tensor([[0.2689, 0.5000], # [0.2689, 0.7311] = softmax([1., 2.])
[0.7311, 0.5000]]), # [0.5000, 0.5000] = softmax([0., 0.])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, the softmax is applied at
each example in the batch.
"""
return _softmax_on(graph, 'nodes', feat)
def softmax_edges(graph, feat):
"""Apply batch-wise graph-level softmax over all the values of edge field
:attr:`feat` in :attr:`graph`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : str
The feature field.
Returns
-------
tensor
The tensor obtained.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
edge features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(2)
>>> g1.add_edges([0, 1], [1, 0])
>>> g1.edata['h'] = th.tensor([[1., 0.], [2., 0.]])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(3)
>>> g2.add_edges([0, 1, 2], [1, 2, 0])
>>> g2.edata['h'] = th.tensor([[1., 0.], [2., 0.], [3., 0.]])
Softmax over edge attribute :attr:`h` in a batched graph.
>>> bg = dgl.batch([g1, g2], edge_attrs='h')
>>> dgl.softmax_edges(bg, 'h')
tensor([[0.2689, 0.5000], # [0.2689, 0.7311] = softmax([1., 2.])
[0.7311, 0.5000], # [0.5000, 0.5000] = softmax([0., 0.])
[0.0900, 0.3333], # [0.0900, 0.2447, 0.6652] = softmax([1., 2., 3.])
[0.2447, 0.3333], # [0.3333, 0.3333, 0.3333] = softmax([0., 0., 0.])
[0.6652, 0.3333]])
Softmax over edge attribute :attr:`h` in a single graph.
>>> dgl.softmax_edges(g1, 'h')
tensor([[0.2689, 0.5000], # [0.2689, 0.7311] = softmax([1., 2.])
[0.7311, 0.5000]]), # [0.5000, 0.5000] = softmax([0., 0.])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, the softmax is applied at each
example in the batch.
"""
return _softmax_on(graph, 'edges', feat)
def broadcast_nodes(graph, feat_data):
"""Broadcast :attr:`feat_data` to all nodes in :attr:`graph`, and return a
tensor of node features.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single graph, and
:math:`(B, *)` for batched graph.
Returns
-------
tensor
The node features tensor with shape :math:`(N, *)`.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
node features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(2)
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(3)
>>> bg = dgl.batch([g1, g2])
>>> feat = th.rand(2, 5)
>>> feat
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
Broadcast feature to all nodes in the batched graph, feat[i] is broadcast to nodes
in the i-th example in the batch.
>>> dgl.broadcast_nodes(bg, feat)
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
Broadcast feature to all nodes in the batched graph.
>>> dgl.broadcast_nodes(g1, feat[0])
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.4325, 0.7710, 0.5541, 0.0544, 0.9368]])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, feat[i] is broadcast to the nodes
in i-th example in the batch.
"""
return _broadcast_on(graph, 'nodes', feat_data)
def broadcast_edges(graph, feat_data):
"""Broadcast :attr:`feat_data` to all edges in :attr:`graph`, and return a
tensor of edge features.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single
graph, and :math:`(B, *)` for batched graph.
Returns
-------
tensor
The edge features tensor with shape :math:`(E, *)`
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
edge features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(2)
>>> g1.add_edges([0, 1], [1, 0])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(3)
>>> g2.add_edges([0, 1, 2], [1, 2, 0])
>>> bg = dgl.batch([g1, g2])
>>> feat = th.rand(2, 5)
>>> feat
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
Broadcast feature to all edges in the batched graph, feat[i] is broadcast to edges
in the i-th example in the batch.
>>> dgl.broadcast_edges(bg, feat)
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
Broadcast feature to all edges in the batched graph.
>>> dgl.broadcast_edges(g1, feat[0])
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.4325, 0.7710, 0.5541, 0.0544, 0.9368]])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, feat[i] is broadcast to
the edges in i-th example in the batch.
"""
return _broadcast_on(graph, 'edges', feat_data)
def topk_nodes(graph, feat, k, descending=True, idx=None):
"""Return graph-wise top-k node features of field :attr:`feat` in
:attr:`graph` ranked by keys at given index :attr:`idx`. If
:attr:`descending` is set to False, return the k smallest elements instead.
If idx is set to None, the function would return top-k value of all
indices, which is equivalent to calling
:code:`torch.topk(graph.ndata[feat], dim=0)`
for each example of the input graph.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : str
The feature field.
k : int
The k in "top-k"
descending : bool
Controls whether to return the largest or smallest elements.
idx : int or None, defaults to None
The index of keys we rank :attr:`feat` on, if set to None, we sort
the whole :attr:`feat`.
Returns
-------
tuple of tensors
The first tensor returns top-k node features of the given graph
with shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size.
The second tensor returns the top-k node indices of the given
graph with shape :math:`(K)`(:math:`(K, D)` if idx is set to None),
if the input graph is a BatchedDGLGraph, a tensor with shape
:math:`(B, K)`(:math:`(B, K, D)` if idx is set to None) would be
returned, where :math:`B` is the batch size.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
node features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(4)
>>> g1.ndata['h'] = th.rand(4, 5)
>>> g1.ndata['h']
tensor([[0.0297, 0.8307, 0.9140, 0.6702, 0.3346],
[0.5901, 0.3030, 0.9280, 0.6893, 0.7997],
[0.0880, 0.6515, 0.4451, 0.7507, 0.5297],
[0.5171, 0.6379, 0.2695, 0.8954, 0.5197]])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(5)
>>> g2.ndata['h'] = th.rand(5, 5)
>>> g2.ndata['h']
tensor([[0.3168, 0.3174, 0.5303, 0.0804, 0.3808],
[0.1323, 0.2766, 0.4318, 0.6114, 0.1458],
[0.1752, 0.9105, 0.5692, 0.8489, 0.0539],
[0.1931, 0.4954, 0.3455, 0.3934, 0.0857],
[0.5065, 0.5182, 0.5418, 0.1520, 0.3872]])
Top-k over node attribute :attr:`h` in a batched graph.
>>> bg = dgl.batch([g1, g2], node_attrs='h')
>>> dgl.topk_nodes(bg, 'h', 3)
(tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]],
[[0.5065, 0.9105, 0.5692, 0.8489, 0.3872],
[0.3168, 0.5182, 0.5418, 0.6114, 0.3808],
[0.1931, 0.4954, 0.5303, 0.3934, 0.1458]]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]],
[[4, 2, 2, 2, 4],
[0, 4, 4, 1, 0],
[3, 3, 0, 3, 1]]]))
Top-k over node attribute :attr:`h` along index -1 in a batched graph.
(used in SortPooling)
>>> dgl.topk_nodes(bg, 'h', 3, idx=-1)
(tensor([[[0.5901, 0.3030, 0.9280, 0.6893, 0.7997],
[0.0880, 0.6515, 0.4451, 0.7507, 0.5297],
[0.5171, 0.6379, 0.2695, 0.8954, 0.5197]],
[[0.5065, 0.5182, 0.5418, 0.1520, 0.3872],
[0.3168, 0.3174, 0.5303, 0.0804, 0.3808],
[0.1323, 0.2766, 0.4318, 0.6114, 0.1458]]]), tensor([[1, 2, 3],
[4, 0, 1]]))
Top-k over node attribute :attr:`h` in a single graph.
>>> dgl.topk_nodes(g1, 'h', 3)
(tensor([[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]]]))
Notes
-----
If an example has :math:`n` nodes and :math:`n<k`, in the first
returned tensor the :math:`n+1` to :math:`k`th rows would be padded
with all zero; in the second returned tensor, the behavior of :math:`n+1`
to :math:`k`th elements is not defined.
"""
return _topk_on(graph, 'nodes', feat, k, descending=descending, idx=idx)
def topk_edges(graph, feat, k, descending=True, idx=None):
"""Return graph-wise top-k edge features of field :attr:`feat` in
:attr:`graph` ranked by keys at given index :attr:`idx`. If
:attr:`descending` is set to False, return the k smallest elements
instead.
If idx is set to None, the function would return top-k value of all
indices, which is equivalent to calling
:code:`torch.topk(graph.edata[feat], dim=0)`
for each example of the input graph.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : str
The feature field.
k : int
The k in "top-k".
descending : bool
Controls whether to return the largest or smallest elements.
idx : int or None, defaults to None
The key index we sort :attr:`feat` on, if set to None, we sort
the whole :attr:`feat`.
Returns
-------
tuple of tensors
The first tensor returns top-k edge features of the given graph
with shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size.
The second tensor returns the top-k edge indices of the given
graph with shape :math:`(K)`(:math:`(K, D)` if idx is set to None),
if the input graph is a BatchedDGLGraph, a tensor with shape
:math:`(B, K)`(:math:`(B, K, D)` if idx is set to None) would be
returned, where :math:`B` is the batch size.
Examples
--------
>>> import dgl
>>> import torch as th
Create two :class:`~dgl.DGLGraph` objects and initialize their
edge features.
>>> g1 = dgl.DGLGraph() # Graph 1
>>> g1.add_nodes(4)
>>> g1.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
>>> g1.edata['h'] = th.rand(4, 5)
>>> g1.edata['h']
tensor([[0.0297, 0.8307, 0.9140, 0.6702, 0.3346],
[0.5901, 0.3030, 0.9280, 0.6893, 0.7997],
[0.0880, 0.6515, 0.4451, 0.7507, 0.5297],
[0.5171, 0.6379, 0.2695, 0.8954, 0.5197]])
>>> g2 = dgl.DGLGraph() # Graph 2
>>> g2.add_nodes(5)
>>> g2.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])
>>> g2.edata['h'] = th.rand(5, 5)
>>> g2.edata['h']
tensor([[0.3168, 0.3174, 0.5303, 0.0804, 0.3808],
[0.1323, 0.2766, 0.4318, 0.6114, 0.1458],
[0.1752, 0.9105, 0.5692, 0.8489, 0.0539],
[0.1931, 0.4954, 0.3455, 0.3934, 0.0857],
[0.5065, 0.5182, 0.5418, 0.1520, 0.3872]])
Top-k over edge attribute :attr:`h` in a batched graph.
>>> bg = dgl.batch([g1, g2], edge_attrs='h')
>>> dgl.topk_edges(bg, 'h', 3)
(tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]],
[[0.5065, 0.9105, 0.5692, 0.8489, 0.3872],
[0.3168, 0.5182, 0.5418, 0.6114, 0.3808],
[0.1931, 0.4954, 0.5303, 0.3934, 0.1458]]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]],
[[4, 2, 2, 2, 4],
[0, 4, 4, 1, 0],
[3, 3, 0, 3, 1]]]))
Top-k over edge attribute :attr:`h` along index -1 in a batched graph.
(used in SortPooling)
>>> dgl.topk_edges(bg, 'h', 3, idx=-1)
(tensor([[[0.5901, 0.3030, 0.9280, 0.6893, 0.7997],
[0.0880, 0.6515, 0.4451, 0.7507, 0.5297],
[0.5171, 0.6379, 0.2695, 0.8954, 0.5197]],
[[0.5065, 0.5182, 0.5418, 0.1520, 0.3872],
[0.3168, 0.3174, 0.5303, 0.0804, 0.3808],
[0.1323, 0.2766, 0.4318, 0.6114, 0.1458]]]), tensor([[1, 2, 3],
[4, 0, 1]]))
Top-k over edge attribute :attr:`h` in a single graph.
>>> dgl.topk_edges(g1, 'h', 3)
(tensor([[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]]]))
Notes
-----
If an example has :math:`n` edges and :math:`n<k`, in the first
returned tensor the :math:`n+1` to :math:`k`th rows would be padded
with all zero; in the second returned tensor, the behavior of :math:`n+1`
to :math:`k`th elements is not defined.
"""
return _topk_on(graph, 'edges', feat, k, descending=descending, idx=idx)
"""Package for mxnet-specific NN modules.""" """Package for mxnet-specific NN modules."""
from .conv import * from .conv import *
from .glob import *
from .softmax import *
"""MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set']
class SumPooling(nn.Block):
r"""Apply sum pooling over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(SumPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute sum pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'SumPooling()'
class AvgPooling(nn.Block):
r"""Apply average pooling over the nodes in the graph.
.. math::
r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(AvgPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute average pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'AvgPooling()'
class MaxPooling(nn.Block):
r"""Apply max pooling over the nodes in the graph.
.. math::
r^{(i)} = \max_{k=1}^{N_i} \left( x^{(i)}_k \right)
"""
def __init__(self):
super(MaxPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute max pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'MaxPooling()'
class SortPooling(nn.Block):
r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
<https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.
Parameters
----------
k : int
The number of nodes to hold for each graph.
"""
def __init__(self, k):
super(SortPooling, self).__init__()
self.k = k
def forward(self, feat, graph):
r"""Compute sort pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`).
"""
# Sort the feature of each node in ascending order.
with graph.local_scope():
feat = feat.sort(axis=-1)
graph.ndata['h'] = feat
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].reshape(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(axis=0)
def __repr__(self):
return 'SortPooling(k={})'.format(self.k)
class GlobalAttentionPooling(nn.Block):
r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
<https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
\left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)
Parameters
----------
gate_nn : gluon.nn.Block
A neural network that computes attention scores for each feature.
feat_nn : gluon.nn.Block, optional
A neural network applied to each feature before combining them
with attention scores.
"""
def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__()
with self.name_scope():
self.gate_nn = gate_nn
self.feat_nn = feat_nn
self._reset_parameters()
def _reset_parameters(self):
self.gate_nn.initialize(mx.init.Xavier())
if self.feat_nn:
self.feat_nn.initialize(mx.init.Xavier())
def forward(self, feat, graph):
r"""Compute global attention pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`).
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
return readout
class Set2Set(nn.Block):
r"""Apply Set2Set (`Order Matters: Sequence to sequence for sets
<https://arxiv.org/pdf/1511.06391.pdf>`__) over the nodes in the graph.
For each individual graph in the batch, set2set computes
.. math::
q_t &= \mathrm{LSTM} (q^*_{t-1})
\alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t)
r_t &= \sum_{i=1}^N \alpha_{i,t} x_i
q^*_t &= q_t \Vert r_t
for this graph.
Parameters
----------
input_dim : int
Size of each input sample
n_iters : int
Number of iterations.
n_layers : int
Number of recurrent layers.
"""
def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__()
self.input_dim = input_dim
self.output_dim = 2 * input_dim
self.n_iters = n_iters
self.n_layers = n_layers
with self.name_scope():
self.lstm = gluon.rnn.LSTM(
self.input_dim, num_layers=n_layers, input_size=self.output_dim)
self._reset_parameters()
def _reset_parameters(self):
self.lstm.initialize(mx.init.Xavier())
def forward(self, feat, graph):
r"""Compute set2set pooling.
Parameters
----------
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`).
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
q_star = nd.zeros((batch_size, self.output_dim), ctx=feat.context)
for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.ndim == 1: # graph is not a BatchedDGLGraph
readout = readout.expand_dims(0)
q_star = nd.concat(q, readout, dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(axis=0)
def __repr__(self):
summary = 'Set2Set('
summary += 'in={}, out={}, ' \
'n_iters={}, n_layers={}'.format(self.input_dim,
self.output_dim,
self.n_iters,
self.n_layers)
summary += ')'
return summary
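# Illustrative sketch, not from this commit: hedged usage of the Gluon pooling blocks
# defined above, assuming this DGL build plus MXNet; graph sizes and the feature width
# below are made up for illustration.
import dgl
import mxnet as mx
from dgl.nn.mxnet.glob import SumPooling, SortPooling, Set2Set

g1 = dgl.DGLGraph(); g1.add_nodes(3)
g2 = dgl.DGLGraph(); g2.add_nodes(4)
bg = dgl.batch([g1, g2])
feat = mx.nd.random.uniform(shape=(7, 5))          # one row per node in the batch

print(SumPooling()(feat, bg).shape)                # expected (2, 5): one readout per graph
print(SortPooling(k=2)(feat, bg).shape)            # expected (2, 10): k * D per graph
print(Set2Set(input_dim=5, n_iters=2, n_layers=1)(feat, bg).shape)  # expected (2, 10)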
"""Package for pytorch-specific NN modules.""" """Package for pytorch-specific NN modules."""
from .conv import * from .conv import *
from .glob import *
from .softmax import *
"""Torch modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235
import torch as th
import torch.nn as nn
import numpy as np
from ... import BatchedDGLGraph
from ...backend import pytorch as F
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set',
'SetTransformerEncoder', 'SetTransformerDecoder']
class SumPooling(nn.Module):
r"""Apply sum pooling over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(SumPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute sum pooling.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
return readout
class AvgPooling(nn.Module):
r"""Apply average pooling over the nodes in the graph.
.. math::
r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
"""
def __init__(self):
super(AvgPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute average pooling.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
return readout
class MaxPooling(nn.Module):
r"""Apply max pooling over the nodes in the graph.
.. math::
r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right)
"""
def __init__(self):
super(MaxPooling, self).__init__()
def forward(self, feat, graph):
r"""Compute max pooling.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`).
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
return readout
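# A minimal usage sketch for the three readouts above (SumPooling, AvgPooling,
# MaxPooling); graph size and feature width are arbitrary illustration values:
#
#   >>> import networkx as nx
#   >>> import torch as th
#   >>> import dgl
#   >>> g = dgl.DGLGraph(nx.path_graph(3))
#   >>> feat = th.rand(3, 4)
#   >>> SumPooling()(feat, g).shape   # same values as feat.sum(0)
#   torch.Size([4])
#   >>> AvgPooling()(feat, g).shape   # same values as feat.mean(0)
#   torch.Size([4])
#   >>> MaxPooling()(feat, g).shape   # same values as feat.max(0)[0]
#   torch.Size([4])
#   >>> bg = dgl.batch([g, g])
#   >>> SumPooling()(th.rand(6, 4), bg).shape  # one row per graph in the batch
#   torch.Size([2, 4])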
class SortPooling(nn.Module):
r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
<https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.
Parameters
----------
k : int
The number of nodes to hold for each graph.
"""
def __init__(self, k):
super(SortPooling, self).__init__()
self.k = k
def forward(self, feat, graph):
r"""Compute sort pooling.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`).
"""
with graph.local_scope():
# Sort the feature of each node in ascending order.
feat, _ = feat.sort(dim=-1)
graph.ndata['h'] = feat
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].view(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(0)
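# Sketch of the SortPooling output layout (arbitrary sizes): each graph is
# reduced to the features of its k top-ranked nodes, flattened to length k * D:
#
#   >>> import networkx as nx
#   >>> import torch as th
#   >>> import dgl
#   >>> g = dgl.DGLGraph(nx.path_graph(6))
#   >>> pool = SortPooling(k=3)
#   >>> pool(th.rand(6, 4), g).shape
#   torch.Size([12])
#   >>> bg = dgl.batch([g, g])
#   >>> pool(th.rand(12, 4), bg).shape
#   torch.Size([2, 12])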
class GlobalAttentionPooling(nn.Module):
r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
<https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.
.. math::
r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
\left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)
Parameters
----------
gate_nn : torch.nn.Module
A neural network that computes attention scores for each feature.
feat_nn : torch.nn.Module, optional
A neural network applied to each feature before combining them
with attention scores.
"""
def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__()
self.gate_nn = gate_nn
self.feat_nn = feat_nn
self._reset_parameters()
def _reset_parameters(self):
for p in self.gate_nn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
if self.feat_nn:
for p in self.feat_nn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat, graph):
r"""Compute global attention pooling.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`).
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata.pop('gate')
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
graph.ndata.pop('r')
return readout
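# Usage sketch (arbitrary sizes): gate_nn must map a D-dimensional node feature
# to a single score (last axis of size 1), while the optional feat_nn may change
# the feature width before the attention-weighted sum. Here nn is torch.nn,
# imported at the top of this module:
#
#   >>> import networkx as nx
#   >>> import torch as th
#   >>> import dgl
#   >>> g = dgl.DGLGraph(nx.path_graph(5))
#   >>> gap = GlobalAttentionPooling(gate_nn=nn.Linear(4, 1), feat_nn=nn.Linear(4, 8))
#   >>> gap(th.rand(5, 4), g).shape
#   torch.Size([8])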
class Set2Set(nn.Module):
r"""Apply Set2Set (`Order Matters: Sequence to sequence for sets
<https://arxiv.org/pdf/1511.06391.pdf>`__) over the nodes in the graph.
For each individual graph in the batch, set2set computes
.. math::
    q_t &= \mathrm{LSTM} (q^*_{t-1}) \\
    \alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t) \\
    r_t &= \sum_{i=1}^N \alpha_{i,t} x_i \\
    q^*_t &= q_t \Vert r_t
for this graph.
Parameters
----------
input_dim : int
Size of each input sample
n_iters : int
Number of iterations.
n_layers : int
Number of recurrent layers.
"""
def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__()
self.input_dim = input_dim
self.output_dim = 2 * input_dim
self.n_iters = n_iters
self.n_layers = n_layers
self.lstm = th.nn.LSTM(self.output_dim, self.input_dim, n_layers)
self._reset_parameters()
def _reset_parameters(self):
for p in self.lstm.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat, graph):
r"""Compute set2set pooling.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(2 * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, 2 * D)`).
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
feat.new_zeros((self.n_layers, batch_size, self.input_dim)))
q_star = feat.new_zeros(batch_size, self.output_dim)
for _ in range(self.n_iters):
q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.input_dim)
e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.dim() == 1: # graph is not a BatchedDGLGraph
readout = readout.unsqueeze(0)
q_star = th.cat([q, readout], dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(0)
def extra_repr(self):
"""Set the extra representation of the module.
which will come into effect when printing the model.
"""
summary = 'n_iters={n_iters}'
return summary.format(**self.__dict__)
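# Usage sketch (arbitrary sizes): the readout dimension is 2 * input_dim because
# q_star concatenates the LSTM query with the attention-weighted readout:
#
#   >>> import networkx as nx
#   >>> import torch as th
#   >>> import dgl
#   >>> g = dgl.DGLGraph(nx.path_graph(4))
#   >>> s2s = Set2Set(input_dim=8, n_iters=2, n_layers=1)
#   >>> s2s(th.rand(4, 8), g).shape
#   torch.Size([16])
#   >>> bg = dgl.batch([g, g, g])
#   >>> s2s(th.rand(12, 8), bg).shape
#   torch.Size([3, 16])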
class MultiHeadAttention(nn.Module):
r"""Multi-Head Attention block, used in Transformer, Set Transformer and so on."""
def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_head
self.d_ff = d_ff
self.proj_q = nn.Linear(d_model, num_heads * d_head, bias=False)
self.proj_k = nn.Linear(d_model, num_heads * d_head, bias=False)
self.proj_v = nn.Linear(d_model, num_heads * d_head, bias=False)
self.proj_o = nn.Linear(num_heads * d_head, d_model, bias=False)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropouth),
nn.Linear(d_ff, d_model)
)
self.droph = nn.Dropout(dropouth)
self.dropa = nn.Dropout(dropouta)
self.norm_in = nn.LayerNorm(d_model)
self.norm_inter = nn.LayerNorm(d_model)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, x, mem, lengths_x, lengths_mem):
"""
Compute multi-head self-attention.
Parameters
----------
x : torch.Tensor
The input tensor used to compute queries.
mem : torch.Tensor
The memory tensor used to compute keys and values.
lengths_x : list
The array of node numbers, used to segment x.
lengths_mem : list
The array of node numbers, used to segment mem.
"""
batch_size = len(lengths_x)
max_len_x = max(lengths_x)
max_len_mem = max(lengths_mem)
queries = self.proj_q(x).view(-1, self.num_heads, self.d_head)
keys = self.proj_k(mem).view(-1, self.num_heads, self.d_head)
values = self.proj_v(mem).view(-1, self.num_heads, self.d_head)
# padding to (B, max_len_x/mem, num_heads, d_head)
queries = F.pad_packed_tensor(queries, lengths_x, 0)
keys = F.pad_packed_tensor(keys, lengths_mem, 0)
values = F.pad_packed_tensor(values, lengths_mem, 0)
# attention score with shape (B, num_heads, max_len_x, max_len_mem)
e = th.einsum('bxhd,byhd->bhxy', queries, keys)
# normalize
e = e / np.sqrt(self.d_head)
# generate mask
mask = th.zeros(batch_size, max_len_x, max_len_mem).to(e.device)
for i in range(batch_size):
mask[i, :lengths_x[i], :lengths_mem[i]].fill_(1)
mask = mask.unsqueeze(1)
e.masked_fill_(mask == 0, -float('inf'))
# apply softmax
alpha = th.softmax(e, dim=-1)
# sum of value weighted by alpha
out = th.einsum('bhxy,byhd->bxhd', alpha, values)
# project to output
out = self.proj_o(
out.contiguous().view(batch_size, max_len_x, self.num_heads * self.d_head))
# pack tensor
out = F.pack_padded_tensor(out, lengths_x)
# intra norm
x = self.norm_in(x + out)
# inter norm
x = self.norm_inter(x + self.ffn(x))
return x
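# Shape walk-through of the forward pass above, assuming self-attention over two
# segments with lengths_x == lengths_mem == [3, 5]: x arrives packed as
# (8, d_model); after projection and padding, queries/keys/values have shape
# (2, 5, num_heads, d_head); the score tensor e is (2, num_heads, 5, 5) with the
# padded positions masked to -inf before the softmax; the attended output is
# projected back to (2, 5, d_model), re-packed to (8, d_model), and then passed
# through the residual + LayerNorm and FFN + LayerNorm sublayers.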
class SetAttentionBlock(nn.Module):
r"""SAB block mentioned in Set-Transformer paper."""
def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
super(SetAttentionBlock, self).__init__()
self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
def forward(self, feat, lengths):
"""
Compute a Set Attention Block.
Parameters
----------
feat : torch.Tensor
The input feature.
lengths : list
The array of node numbers, used to segment feat tensor.
"""
return self.mha(feat, feat, lengths, lengths)
class InducedSetAttentionBlock(nn.Module):
r"""ISAB block mentioned in Set-Transformer paper."""
def __init__(self, m, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
super(InducedSetAttentionBlock, self).__init__()
self.m = m
self.d_model = d_model
self.inducing_points = nn.Parameter(
th.FloatTensor(m, d_model)
)
self.mha = nn.ModuleList([
MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta) for _ in range(2)])
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.inducing_points)
def forward(self, feat, lengths):
"""
Compute an Induced Set Attention Block.
Parameters
----------
feat : torch.Tensor
The input feature.
lengths : list
The array of node numbers, used to segment feat tensor.
Returns
-------
torch.Tensor
The output feature.
"""
batch_size = len(lengths)
query = self.inducing_points.repeat(batch_size, 1)
memory = self.mha[0](query, feat, [self.m] * batch_size, lengths)
return self.mha[1](feat, memory, lengths, [self.m] * batch_size)
def extra_repr(self):
"""Set the extra representation of the module.
which will come into effect when printing the model.
"""
shape_str = '({}, {})'.format(self.inducing_points.shape[0], self.inducing_points.shape[1])
return 'InducedVector: ' + shape_str
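# In brief, the ISAB block above runs two attention passes: the m learnable
# inducing points first attend over the full set (mha[0]) to form an m-point
# summary, and the set then attends over that summary (mha[1]). This keeps the
# per-block attention cost proportional to N * m rather than N * N.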
class PMALayer(nn.Module):
r"""Pooling by Multihead Attention, used in the Decoder Module of Set Transformer."""
def __init__(self, k, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
super(PMALayer, self).__init__()
self.k = k
self.d_model = d_model
self.seed_vectors = nn.Parameter(
th.FloatTensor(k, d_model)
)
self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropouth),
nn.Linear(d_ff, d_model)
)
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.seed_vectors)
def forward(self, feat, lengths):
"""
Compute Pooling by Multihead Attention.
Parameters
----------
feat : torch.Tensor
The input feature.
lengths : list
The array of node numbers, used to segment feat tensor.
Returns
-------
torch.Tensor
The output feature.
"""
batch_size = len(lengths)
query = self.seed_vectors.repeat(batch_size, 1)
return self.mha(query, self.ffn(feat), [self.k] * batch_size, lengths)
def extra_repr(self):
"""Set the extra representation of the module.
which will come into effect when printing the model.
"""
shape_str = '({}, {})'.format(self.seed_vectors.shape[0], self.seed_vectors.shape[1])
return 'SeedVector: ' + shape_str
class SetTransformerEncoder(nn.Module):
r"""The Encoder module in `Set Transformer: A Framework for Attention-based
Permutation-Invariant Neural Networks <https://arxiv.org/pdf/1810.00825.pdf>`__.
Parameters
----------
d_model : int
Hidden size of the model.
n_heads : int
Number of heads.
d_head : int
Hidden size of each head.
d_ff : int
Kernel size in FFN (Positionwise Feed-Forward Network) layer.
n_layers : int
Number of layers.
block_type : str
Building block type: 'sab' (Set Attention Block) or 'isab' (Induced
Set Attention Block).
m : int or None
Number of induced vectors in ISAB Block, set to None if block type
is 'sab'.
dropouth : float
Dropout rate of each sublayer.
dropouta : float
Dropout rate of attention heads.
"""
def __init__(self, d_model, n_heads, d_head, d_ff,
n_layers=1, block_type='sab', m=None, dropouth=0., dropouta=0.):
super(SetTransformerEncoder, self).__init__()
self.n_layers = n_layers
self.block_type = block_type
self.m = m
layers = []
if block_type == 'isab' and m is None:
raise KeyError('The number of inducing points is not specified in ISAB block.')
for _ in range(n_layers):
if block_type == 'sab':
layers.append(
SetAttentionBlock(d_model, n_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta))
elif block_type == 'isab':
layers.append(
InducedSetAttentionBlock(m, d_model, n_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta))
else:
raise KeyError("Unrecognized block type {}: we only support sab/isab")
self.layers = nn.ModuleList(layers)
def forward(self, feat, graph):
"""
Compute the Encoder part of Set Transformer.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(N, D)`.
"""
if isinstance(graph, BatchedDGLGraph):
lengths = graph.batch_num_nodes
else:
lengths = [graph.number_of_nodes()]
for layer in self.layers:
feat = layer(feat, lengths)
return feat
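# Usage sketch (arbitrary sizes): the encoder preserves the per-node shape
# (N, D), so it is typically followed by a readout such as SetTransformerDecoder:
#
#   >>> import networkx as nx
#   >>> import torch as th
#   >>> import dgl
#   >>> g = dgl.DGLGraph(nx.path_graph(6))
#   >>> enc = SetTransformerEncoder(d_model=16, n_heads=4, d_head=4, d_ff=32,
#   ...                             n_layers=1, block_type='isab', m=2)
#   >>> enc(th.rand(6, 16), g).shape
#   torch.Size([6, 16])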
class SetTransformerDecoder(nn.Module):
r"""The Decoder module in `Set Transformer: A Framework for Attention-based
Permutation-Invariant Neural Networks <https://arxiv.org/pdf/1810.00825.pdf>`__.
Parameters
----------
d_model : int
Hidden size of the model.
num_heads : int
Number of heads.
d_head : int
Hidden size of each head.
d_ff : int
Kernel size in FFN (Positionwise Feed-Forward Network) layer.
n_layers : int
Number of layers.
k : int
Number of seed vectors in PMA (Pooling by Multihead Attention) layer.
dropouth : float
Dropout rate of each sublayer.
dropouta : float
Dropout rate of attention heads.
"""
def __init__(self, d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0., dropouta=0.):
super(SetTransformerDecoder, self).__init__()
self.n_layers = n_layers
self.k = k
self.d_model = d_model
self.pma = PMALayer(k, d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
layers = []
for _ in range(n_layers):
layers.append(
SetAttentionBlock(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta))
self.layers = nn.ModuleList(layers)
def forward(self, feat, graph):
"""
Compute the decoder part of Set Transformer.
Parameters
----------
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`).
"""
if isinstance(graph, BatchedDGLGraph):
len_pma = graph.batch_num_nodes
len_sab = [self.k] * graph.batch_size
else:
len_pma = [graph.number_of_nodes()]
len_sab = [self.k]
feat = self.pma(feat, len_pma)
for layer in self.layers:
feat = layer(feat, len_sab)
if isinstance(graph, BatchedDGLGraph):
return feat.view(graph.batch_size, self.k * self.d_model)
else:
return feat.view(self.k * self.d_model)
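# Usage sketch (arbitrary sizes): PMA pools the N node features into k
# seed-vector slots, the SAB layers refine them, and the result is flattened to
# k * d_model per graph:
#
#   >>> import networkx as nx
#   >>> import torch as th
#   >>> import dgl
#   >>> g = dgl.DGLGraph(nx.path_graph(6))
#   >>> dec = SetTransformerDecoder(d_model=16, num_heads=4, d_head=4, d_ff=32,
#   ...                             n_layers=1, k=2)
#   >>> dec(th.rand(6, 16), g).shape
#   torch.Size([32])
#   >>> bg = dgl.batch([g, g, g])
#   >>> dec(th.rand(18, 16), bg).shape
#   torch.Size([3, 32])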
import dgl
import backend as F
import networkx as nx
def test_simple_readout():
g1 = dgl.DGLGraph()
...@@ -57,8 +58,192 @@ def test_simple_readout():
max_bg_e = dgl.max_edges(g, 'x')
assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
# TODO(zihao): fix -inf issue
# assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
def test_topk_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(14))
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
# to test the case where k > number of nodes.
dgl.topk_nodes(g0, 'x', 20, idx=-1)
# test correctness
val, indices = dgl.topk_nodes(g0, 'x', 5, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:5], (5,))
assert F.allclose(ground_truth, indices)
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_nodes(), 10))
bg = dgl.batch([g0, g1])
bg.ndata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of nodes.
dgl.topk_nodes(bg, 'x', 16, idx=1)
# test correctness
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:6], (6,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:6], (6,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_topk_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(14))
feat0 = F.randn((g0.number_of_edges(), 10))
g0.edata['x'] = feat0
# to test the case where k > number of edges.
dgl.topk_edges(g0, 'x', 30, idx=-1)
# test correctness
val, indices = dgl.topk_edges(g0, 'x', 7, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:7], (7,))
assert F.allclose(ground_truth, indices)
g0.edata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_edges(), 10))
bg = dgl.batch([g0, g1])
bg.edata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of edges.
dgl.topk_edges(bg, 'x', 33, idx=1)
# test correctness
val, indices = dgl.topk_edges(bg, 'x', 4, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:4], (4,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:4], (4,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
val, indices = dgl.topk_edges(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_softmax_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(9))
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
ground_truth = F.softmax(feat0, dim=0)
assert F.allclose(dgl.softmax_nodes(g0, 'x'), ground_truth)
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(3))
g3 = dgl.DGLGraph()
g4 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g0, g1, g2, g3, g4])
feat1 = F.randn((g1.number_of_nodes(), 10))
feat2 = F.randn((g2.number_of_nodes(), 10))
feat4 = F.randn((g4.number_of_nodes(), 10))
bg.ndata['x'] = F.cat([feat0, feat1, feat2, feat4], 0)
ground_truth = F.cat([
F.softmax(feat0, 0),
F.softmax(feat1, 0),
F.softmax(feat2, 0),
F.softmax(feat4, 0)
], 0)
assert F.allclose(dgl.softmax_nodes(bg, 'x'), ground_truth)
def test_softmax_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((g0.number_of_edges(), 10))
g0.edata['x'] = feat0
ground_truth = F.softmax(feat0, dim=0)
assert F.allclose(dgl.softmax_edges(g0, 'x'), ground_truth)
g0.edata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(3))
g3 = dgl.DGLGraph()
g4 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g0, g1, g2, g3, g4])
feat1 = F.randn((g1.number_of_edges(), 10))
feat2 = F.randn((g2.number_of_edges(), 10))
feat4 = F.randn((g4.number_of_edges(), 10))
bg.edata['x'] = F.cat([feat0, feat1, feat2, feat4], 0)
ground_truth = F.cat([
F.softmax(feat0, 0),
F.softmax(feat1, 0),
F.softmax(feat2, 0),
F.softmax(feat4, 0)
], 0)
assert F.allclose(dgl.softmax_edges(bg, 'x'), ground_truth)
def test_broadcast_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
[feat0] * g0.number_of_nodes() +\
[feat1] * g1.number_of_nodes() +\
[feat2] * g2.number_of_nodes() +\
[feat3] * g3.number_of_nodes(), 0
)
assert F.allclose(dgl.broadcast_nodes(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
), ground_truth)
def test_broadcast_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
[feat0] * g0.number_of_edges() +\
[feat1] * g1.number_of_edges() +\
[feat2] * g2.number_of_edges() +\
[feat3] * g3.number_of_edges(), 0
)
assert F.allclose(dgl.broadcast_edges(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
), ground_truth)
if __name__ == '__main__':
test_simple_readout()
test_topk_nodes()
test_topk_edges()
test_softmax_nodes()
test_softmax_edges()
test_broadcast_nodes()
test_broadcast_edges()
...@@ -3,11 +3,10 @@ import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
from mxnet import autograd, gluon
def check_close(a, b):
assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)
def _AXWb(A, X, W, b):
X = mx.nd.dot(X, W.data(X.context))
...@@ -26,13 +25,13 @@ def test_graph_conv():
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = mx.nd.ones((3, 5, 5))
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
conv.initialize(ctx=ctx)
...@@ -69,7 +68,92 @@ def test_graph_conv():
assert len(g.ndata) == 1
assert len(g.edata) == 0
assert "h" in g.ndata
check_close(g.ndata['h'], 2 * mx.nd.ones((3, 1)))
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
print(s2s)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = s2s(h0, g)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = s2s(h0, bg)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
print(gap)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = gap(h0, g)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = gap(h0, bg)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_simple_pool():
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
avg_pool = nn.AvgPooling()
max_pool = nn.MaxPooling()
sort_pool = nn.SortPooling(10) # k = 10
print(sum_pool, avg_pool, max_pool, sort_pool)
# test#1: basic
h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
h1 = sum_pool(h0, g)
check_close(h1, mx.nd.sum(h0, 0))
h1 = avg_pool(h0, g)
check_close(h1, mx.nd.mean(h0, 0))
h1 = max_pool(h0, g)
check_close(h1, mx.nd.max(h0, 0))
h1 = sort_pool(h0, g)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
h1 = sum_pool(h0, bg)
truth = mx.nd.stack(mx.nd.sum(h0[:15], 0),
mx.nd.sum(h0[15:20], 0),
mx.nd.sum(h0[20:35], 0),
mx.nd.sum(h0[35:40], 0),
mx.nd.sum(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = avg_pool(h0, bg)
truth = mx.nd.stack(mx.nd.mean(h0[:15], 0),
mx.nd.mean(h0[15:20], 0),
mx.nd.mean(h0[20:35], 0),
mx.nd.mean(h0[35:40], 0),
mx.nd.mean(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = max_pool(h0, bg)
truth = mx.nd.stack(mx.nd.max(h0[:15], 0),
mx.nd.max(h0[15:20], 0),
mx.nd.max(h0[20:35], 0),
mx.nd.max(h0[35:40], 0),
mx.nd.max(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = sort_pool(h0, bg)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def uniform_attention(g, shape):
a = mx.nd.ones(shape)
...@@ -97,3 +181,6 @@ def test_edge_softmax():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
...@@ -61,6 +61,124 @@ def test_graph_conv():
new_weight = conv.weight.data
assert not th.allclose(old_weight, new_weight)
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
print(s2s)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = s2s(h0, g)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(11))
g2 = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g1, g2])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = s2s(h0, bg)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
print(gap)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = gap(h0, g)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = gap(h0, bg)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
def test_simple_pool():
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
avg_pool = nn.AvgPooling()
max_pool = nn.MaxPooling()
sort_pool = nn.SortPooling(10) # k = 10
print(sum_pool, avg_pool, max_pool, sort_pool)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 5)
h1 = sum_pool(h0, g)
assert th.allclose(h1, th.sum(h0, 0))
h1 = avg_pool(h0, g)
assert th.allclose(h1, th.mean(h0, 0))
h1 = max_pool(h0, g)
assert th.allclose(h1, th.max(h0, 0)[0])
h1 = sort_pool(h0, g)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = th.rand(bg.number_of_nodes(), 5)
h1 = sum_pool(h0, bg)
truth = th.stack([th.sum(h0[:15], 0),
th.sum(h0[15:20], 0),
th.sum(h0[20:35], 0),
th.sum(h0[35:40], 0),
th.sum(h0[40:55], 0)], 0)
assert th.allclose(h1, truth)
h1 = avg_pool(h0, bg)
truth = th.stack([th.mean(h0[:15], 0),
th.mean(h0[15:20], 0),
th.mean(h0[20:35], 0),
th.mean(h0[35:40], 0),
th.mean(h0[40:55], 0)], 0)
assert th.allclose(h1, truth)
h1 = max_pool(h0, bg)
truth = th.stack([th.max(h0[:15], 0)[0],
th.max(h0[15:20], 0)[0],
th.max(h0[20:35], 0)[0],
th.max(h0[35:40], 0)[0],
th.max(h0[40:55], 0)[0]], 0)
assert th.allclose(h1, truth)
h1 = sort_pool(h0, bg)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
g = dgl.DGLGraph(nx.path_graph(15))
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
print(st_enc_0, st_enc_1, st_dec)
# test#1: basic
h0 = th.rand(g.number_of_nodes(), 50)
h1 = st_enc_0(h0, g)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, g)
assert h1.shape == h0.shape
h2 = st_dec(h1, g)
assert h2.shape[0] == 200 and h2.dim() == 1
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
g2 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g1, g2])
h0 = th.rand(bg.number_of_nodes(), 50)
h1 = st_enc_0(h0, bg)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, bg)
assert h1.shape == h0.shape
h2 = st_dec(h1, bg)
assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2
def uniform_attention(g, shape):
a = th.ones(shape)
target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
...@@ -130,3 +248,7 @@ def test_edge_softmax():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
test_set_trans()