Unverified commit 2fa7f71a authored by zhjwy9343, committed by GitHub

[Doc] new nn api doc (#2019)



* Add dotproduct attention

* [Feature] Add dotproduct attention

* [Feature] Add dotproduct attention

* [Feature] Add dotproduct attention

* [New] Update landing page

* [New] Update landing page

* [New] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Improvement] use dgl build-in in dotgatconv

* [Doc] review API doc string bottom up

* [Doc] Add doc of input and output features

* [Doc] Update doc string for pooling and transformer.

* [Doc] Reformat doc string and change some wordings.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent f4f78803
@@ -181,12 +181,8 @@ class DotGatConv(nn.Module):
         # Step 2. edge softmax to compute attention scores
         graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])
-        # Step 3. Broadcast softmax value to each edge, and then attention is done
-        graph.apply_edges(lambda edges: {'attn': edges.src['ft'] * \
-                                                 edges.data['sa'].unsqueeze(dim=0).T})
-        # Step 4. Aggregate attention to dst,user nodes, so formula 7 is done
-        graph.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'))
+        # Step 3. Broadcast softmax value to each edge, and aggregate dst node
+        graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
         # output results to the destination nodes
         rst = graph.dstdata['agg_u']
...
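The rewritten Step 3 leans on DGL's built-in message functions: `fn.u_mul_e('ft', 'sa', 'attn')` multiplies each source node's feature `ft` by the softmax score `sa` on the edge, and `fn.sum('attn', 'agg_u')` sums the weighted messages at each destination node, so the old Steps 3 and 4 fuse into a single `update_all` call. A minimal self-contained sketch of the same pattern, assuming DGL 0.5+ with the PyTorch backend (the toy graph and feature sizes are illustrative, not from the patch):

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> from dgl.ops import edge_softmax
>>>
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))        # toy 3-node cycle
>>> g.ndata['ft'] = th.randn(3, 4)               # node features
>>> g.edata['a'] = th.randn(3, 1)                # unnormalized attention logits
>>> g.edata['sa'] = edge_softmax(g, g.edata['a'])
>>> g.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
>>> g.dstdata['agg_u'].shape
torch.Size([3, 4])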
@@ -12,11 +12,23 @@ def pairwise_squared_distance(x):
 class KNNGraph(nn.Module):
-    r"""Layer that transforms one point set into a graph, or a batch of
+    r"""
+    Description
+    -----------
+    Layer that transforms one point set into a graph, or a batch of
     point sets with the same number of points into a union of those graphs.
 
-    If a batch of point set is provided, then the point :math:`j` in point
-    set :math:`i` is mapped to graph node ID :math:`i \times M + j`, where
+    The KNNGraph is implemented in the following steps:
+
+    1. Compute an NxN matrix of pairwise distances for all points.
+    2. Pick the k points with the smallest distances to each point as its
+       k-nearest neighbors.
+    3. Construct a graph with an edge from each of a point's k-nearest
+       neighbors to the point itself.
+
+    The overall computational complexity is :math:`O(N^2(\log N + D))`.
+
+    If a batch of point sets is provided, the point :math:`j` in point
+    set :math:`i` is mapped to graph node ID :math:`i \times M + j`, where
     :math:`M` is the number of nodes in each point set.
 
     The predecessors of each node are the k-nearest neighbors of the
@@ -25,7 +37,30 @@ class KNNGraph(nn.Module):
     Parameters
     ----------
     k : int
-        The number of neighbors
+        The number of neighbors.
+
+    Notes
+    -----
+    The nearest neighbors found for a node include the node itself.
+
+    Examples
+    --------
+    The following example uses PyTorch backend.
+
+    >>> import torch
+    >>> from dgl.nn.pytorch.factory import KNNGraph
+    >>>
+    >>> kg = KNNGraph(2)
+    >>> x = torch.tensor([[0, 1],
+    ...                   [1, 2],
+    ...                   [1, 3],
+    ...                   [100, 101],
+    ...                   [101, 102],
+    ...                   [50, 50]])
+    >>> g = kg(x)
+    >>> print(g.edges())
+    (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
+     tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
     """
 
     def __init__(self, k):
         super(KNNGraph, self).__init__()
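The three steps listed in the new KNNGraph docstring above map to only a few lines of PyTorch. A rough standalone sketch, assuming Euclidean points; the helper name `knn_edges` is illustrative, and DGL's `knn_graph` remains the authoritative implementation:

import torch

def knn_edges(x, k):
    # Step 1: N x N pairwise squared distances, via |a - b|^2 = |a|^2 - 2 a.b + |b|^2.
    sq = (x * x).sum(dim=1, keepdim=True)         # (N, 1)
    dist = sq - 2 * (x @ x.t()) + sq.t()          # (N, N)
    # Step 2: indices of the k smallest distances per point (includes the point itself).
    _, nbrs = dist.topk(k, dim=1, largest=False)  # (N, k)
    # Step 3: draw an edge from each neighbor (source) to its point (destination).
    src = nbrs.reshape(-1)
    dst = torch.arange(x.shape[0]).repeat_interleave(k)
    return src, dst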
@@ -33,7 +68,9 @@ class KNNGraph(nn.Module):
     #pylint: disable=invalid-name
     def forward(self, x):
-        """Forward computation.
+        """
+        Forward computation.
 
         Parameters
         ----------
@@ -45,19 +82,23 @@ class KNNGraph(nn.Module):
         Returns
         -------
         DGLGraph
-            A DGLGraph with no features.
+            A DGLGraph without features.
         """
         return knn_graph(x, self.k)
 
 class SegmentedKNNGraph(nn.Module):
-    r"""Layer that transforms one point set into a graph, or a batch of
+    r"""
+    Description
+    -----------
+    Layer that transforms one point set into a graph, or a batch of
     point sets with different numbers of points into a union of those graphs.
 
-    If a batch of point set is provided, then the point :math:`j` in point
-    set :math:`i` is mapped to graph node ID
+    If a batch of point sets is provided, then the point :math:`j` in the point
+    set :math:`i` is mapped to graph node ID
     :math:`\sum_{p<i} |V_p| + j`, where :math:`|V_p|` means the number of
-    points in point set :math:`p`.
+    points in the point set :math:`p`.
 
     The predecessors of each node are the k-nearest neighbors of the
     corresponding point.
@@ -65,7 +106,34 @@ class SegmentedKNNGraph(nn.Module):
     Parameters
     ----------
     k : int
-        The number of neighbors
+        The number of neighbors.
+
+    Notes
+    -----
+    The nearest neighbors found for a node include the node itself.
+
+    Examples
+    --------
+    The following example uses PyTorch backend.
+
+    >>> import torch
+    >>> from dgl.nn.pytorch.factory import SegmentedKNNGraph
+    >>>
+    >>> kg = SegmentedKNNGraph(2)
+    >>> x = torch.tensor([[0, 1],
+    ...                   [1, 2],
+    ...                   [1, 3],
+    ...                   [100, 101],
+    ...                   [101, 102],
+    ...                   [50, 50],
+    ...                   [24, 25],
+    ...                   [25, 24]])
+    >>> g = kg(x, [3, 3, 2])
+    >>> print(g.edges())
+    (tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]),
+     tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]))
     """
 
     def __init__(self, k):
         super(SegmentedKNNGraph, self).__init__()
@@ -73,20 +141,22 @@ class SegmentedKNNGraph(nn.Module):
     #pylint: disable=invalid-name
     def forward(self, x, segs):
-        """Forward computation.
+        r"""Forward computation.
 
         Parameters
         ----------
         x : Tensor
             :math:`(M, D)` where :math:`M` means the total number of points
-            in all point sets.
+            in all point sets, and :math:`D` means the size of features.
         segs : iterable of int
             :math:`(N)` integers where :math:`N` means the number of point
-            sets. The elements must sum up to :math:`M`.
+            sets. The elements must sum up to :math:`M`, and every element
+            must be no less than :math:`k`.
 
         Returns
         -------
         DGLGraph
-            A DGLGraph with no features.
+            A DGLGraph without features.
         """
         return segmented_knn_graph(x, self.k, segs)
@@ -14,31 +14,76 @@ __all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
            'SetTransformerEncoder', 'SetTransformerDecoder', 'WeightAndSum']
 
 class SumPooling(nn.Module):
-    r"""Apply sum pooling over the nodes in the graph.
+    r"""
+    Description
+    -----------
+    Apply sum pooling over the nodes in a graph.
 
     .. math::
         r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
 
+    Notes
+    -----
+    Input: Could be one graph, or a batch of graphs. If using a batch of graphs,
+    make sure nodes in all graphs have the same feature size, and concatenate
+    the node features into one tensor as the input.
+
+    Examples
+    --------
+    The following example uses PyTorch backend.
+
+    >>> import dgl
+    >>> import torch as th
+    >>> from dgl.nn.pytorch.glob import SumPooling
+    >>>
+    >>> g1 = dgl.DGLGraph()
+    >>> g1.add_nodes(2)
+    >>> g1_node_feats = th.ones(2, 5)
+    >>>
+    >>> g2 = dgl.DGLGraph()
+    >>> g2.add_nodes(3)
+    >>> g2_node_feats = th.ones(3, 5)
+    >>>
+    >>> sumpool = SumPooling()
+
+    Case 1: Input a single graph
+
+    >>> sumpool(g1, g1_node_feats)
+    tensor([[2., 2., 2., 2., 2.]])
+
+    Case 2: Input a batch of graphs
+
+    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.
+
+    >>> batch_g = dgl.batch([g1, g2])
+    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
+    >>>
+    >>> sumpool(batch_g, batch_f)
+    tensor([[2., 2., 2., 2., 2.],
+            [3., 3., 3., 3., 3.]])
     """
 
     def __init__(self):
         super(SumPooling, self).__init__()
 
     def forward(self, graph, feat):
-        r"""Compute sum pooling.
+        r"""
+        Compute sum pooling.
 
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            A DGLGraph or a batch of DGLGraphs.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, *)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)`, where :math:`N` is the
+            number of nodes in the graph, and :math:`D` means the size of features.
 
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(B, *)`, where
-            :math:`B` refers to the batch size.
+            The output feature with shape :math:`(B, D)`, where :math:`B` refers
+            to the batch size of input graphs.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -47,30 +92,76 @@ class SumPooling(nn.Module):
 
 class AvgPooling(nn.Module):
-    r"""Apply average pooling over the nodes in the graph.
+    r"""
+    Description
+    -----------
+    Apply average pooling over the nodes in a graph.
 
     .. math::
         r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
 
+    Notes
+    -----
+    Input: Could be one graph, or a batch of graphs. If using a batch of graphs,
+    make sure nodes in all graphs have the same feature size, and concatenate
+    the node features into one tensor as the input.
+
+    Examples
+    --------
+    The following example uses PyTorch backend.
+
+    >>> import dgl
+    >>> import torch as th
+    >>> from dgl.nn.pytorch.glob import AvgPooling
+    >>>
+    >>> g1 = dgl.DGLGraph()
+    >>> g1.add_nodes(2)
+    >>> g1_node_feats = th.ones(2, 5)
+    >>>
+    >>> g2 = dgl.DGLGraph()
+    >>> g2.add_nodes(3)
+    >>> g2_node_feats = th.ones(3, 5)
+    >>>
+    >>> avgpool = AvgPooling()
+
+    Case 1: Input a single graph
+
+    >>> avgpool(g1, g1_node_feats)
+    tensor([[1., 1., 1., 1., 1.]])
+
+    Case 2: Input a batch of graphs
+
+    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.
+
+    >>> batch_g = dgl.batch([g1, g2])
+    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
+    >>>
+    >>> avgpool(batch_g, batch_f)
+    tensor([[1., 1., 1., 1., 1.],
+            [1., 1., 1., 1., 1.]])
     """
 
    def __init__(self):
         super(AvgPooling, self).__init__()
 
     def forward(self, graph, feat):
-        r"""Compute average pooling.
+        r"""
+        Compute average pooling.
 
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            A DGLGraph or a batch of DGLGraphs.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, *)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)`, where :math:`N` is the
+            number of nodes in the graph, and :math:`D` means the size of features.
 
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(B, *)`, where
-            :math:`B` refers to the batch size.
+            The output feature with shape :math:`(B, D)`, where
+            :math:`B` refers to the batch size of input graphs.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -79,10 +170,54 @@ class AvgPooling(nn.Module):
 
 class MaxPooling(nn.Module):
-    r"""Apply max pooling over the nodes in the graph.
+    r"""
+    Description
+    -----------
+    Apply max pooling over the nodes in a graph.
 
     .. math::
         r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right)
 
+    Notes
+    -----
+    Input: Could be one graph, or a batch of graphs. If using a batch of graphs,
+    make sure nodes in all graphs have the same feature size, and concatenate
+    the node features into one tensor as the input.
+
+    Examples
+    --------
+    The following example uses PyTorch backend.
+
+    >>> import dgl
+    >>> import torch as th
+    >>> from dgl.nn.pytorch.glob import MaxPooling
+    >>>
+    >>> g1 = dgl.DGLGraph()
+    >>> g1.add_nodes(2)
+    >>> g1_node_feats = th.ones(2, 5)
+    >>>
+    >>> g2 = dgl.DGLGraph()
+    >>> g2.add_nodes(3)
+    >>> g2_node_feats = th.ones(3, 5)
+    >>>
+    >>> maxpool = MaxPooling()
+
+    Case 1: Input a single graph
+
+    >>> maxpool(g1, g1_node_feats)
+    tensor([[1., 1., 1., 1., 1.]])
+
+    Case 2: Input a batch of graphs
+
+    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.
+
+    >>> batch_g = dgl.batch([g1, g2])
+    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
+    >>>
+    >>> maxpool(batch_g, batch_f)
+    tensor([[1., 1., 1., 1., 1.],
+            [1., 1., 1., 1., 1.]])
     """
 
     def __init__(self):
         super(MaxPooling, self).__init__()
@@ -93,9 +228,9 @@ class MaxPooling(nn.Module):
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            A DGLGraph or a batch of DGLGraphs.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, *)` where
+            The input feature with shape :math:`(N, *)`, where
             :math:`N` is the number of nodes in the graph.
 
         Returns
@@ -111,34 +246,79 @@ class MaxPooling(nn.Module):
 
 class SortPooling(nn.Module):
-    r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
-    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.
+    r"""
+    Description
+    -----------
+    Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
+    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in a graph.
 
     Parameters
     ----------
     k : int
         The number of nodes to hold for each graph.
 
+    Notes
+    -----
+    Input: Could be one graph, or a batch of graphs. If using a batch of graphs,
+    make sure nodes in all graphs have the same feature size, and concatenate
+    the node features into one tensor as the input.
+
+    Examples
+    --------
+    >>> import dgl
+    >>> import torch as th
+    >>> from dgl.nn.pytorch.glob import SortPooling
+    >>>
+    >>> g1 = dgl.DGLGraph()
+    >>> g1.add_nodes(2)
+    >>> g1_node_feats = th.ones(2, 5)
+    >>>
+    >>> g2 = dgl.DGLGraph()
+    >>> g2.add_nodes(3)
+    >>> g2_node_feats = th.ones(3, 5)
+    >>>
+    >>> sortpool = SortPooling(k=2)
+
+    Case 1: Input a single graph
+
+    >>> sortpool(g1, g1_node_feats)
+    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
+
+    Case 2: Input a batch of graphs
+
+    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.
+
+    >>> batch_g = dgl.batch([g1, g2])
+    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
+    >>>
+    >>> sortpool(batch_g, batch_f)
+    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
+            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
     """
 
     def __init__(self, k):
         super(SortPooling, self).__init__()
         self.k = k
 
     def forward(self, graph, feat):
-        r"""Compute sort pooling.
+        r"""
+        Compute sort pooling.
 
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            A DGLGraph or a batch of DGLGraphs.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, D)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)`, where :math:`N` is the
+            number of nodes in the graph, and :math:`D` means the size of features.
 
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(B, k * D)`, where
-            :math:`B` refers to the batch size.
+            The output feature with shape :math:`(B, k * D)`, where :math:`B` refers
+            to the batch size of input graphs.
         """
         with graph.local_scope():
             # Sort the feature of each node in ascending order.
@@ -151,8 +331,12 @@ class SortPooling(nn.Module):
 
 class GlobalAttentionPooling(nn.Module):
-    r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
-    <https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.
+    r"""
+    Description
+    -----------
+    Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
+    <https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in a graph.
 
     .. math::
         r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
@@ -163,8 +347,8 @@ class GlobalAttentionPooling(nn.Module):
     gate_nn : torch.nn.Module
         A neural network that computes attention scores for each feature.
     feat_nn : torch.nn.Module, optional
-        A neural network applied to each feature before combining them
-        with attention scores.
+        A neural network applied to each feature before combining them with attention
+        scores.
     """
 
     def __init__(self, gate_nn, feat_nn=None):
         super(GlobalAttentionPooling, self).__init__()
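Unlike the pooling classes above, GlobalAttentionPooling gains no Examples section in this patch. A minimal usage sketch in the same doctest style; the Linear gate and the tensor shapes are assumptions for illustration:

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.glob import GlobalAttentionPooling
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> feat = th.ones(3, 5)
>>> gate_nn = th.nn.Linear(5, 1)  # maps each node feature to a scalar attention score
>>> gap = GlobalAttentionPooling(gate_nn)
>>> gap(g, feat).shape            # one pooled row per graph
torch.Size([1, 5])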
@@ -172,21 +356,23 @@ class GlobalAttentionPooling(nn.Module):
         self.feat_nn = feat_nn
 
     def forward(self, graph, feat):
-        r"""Compute global attention pooling.
+        r"""
+        Compute global attention pooling.
 
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            A DGLGraph or a batch of DGLGraphs.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, D)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)` where :math:`N` is the
+            number of nodes in the graph, and :math:`D` means the size of features.
 
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(B, D)`, where
-            :math:`B` refers to the batch size.
+            The output feature with shape :math:`(B, D)`, where :math:`B` refers
+            to the batch size.
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
@@ -205,9 +391,10 @@ class GlobalAttentionPooling(nn.Module):
 
 class Set2Set(nn.Module):
-    r"""Apply Set2Set (`Order Matters: Sequence to sequence for sets
-    <https://arxiv.org/pdf/1511.06391.pdf>`__) over the nodes in the graph.
+    r"""
+    Description
+    -----------
     For each individual graph in the batch, set2set computes
 
     .. math::
@@ -224,11 +411,11 @@ class Set2Set(nn.Module):
     Parameters
     ----------
     input_dim : int
-        Size of each input sample
+        The size of each input sample.
     n_iters : int
-        Number of iterations.
+        The number of iterations.
     n_layers : int
-        Number of recurrent layers.
+        The number of recurrent layers.
     """
 
     def __init__(self, input_dim, n_iters, n_layers):
         super(Set2Set, self).__init__()
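Set2Set likewise has no Examples section here. A minimal usage sketch; note that in DGL's implementation the readout concatenates the LSTM query with the attention readout, so the output width below is 2 * input_dim (an implementation detail, stated here as an assumption):

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.glob import Set2Set
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> feat = th.rand(3, 5)
>>> s2s = Set2Set(input_dim=5, n_iters=2, n_layers=1)
>>> s2s(g, feat).shape            # q and r concatenated: 2 * input_dim
torch.Size([1, 10])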
@@ -244,21 +431,22 @@ class Set2Set(nn.Module):
         self.lstm.reset_parameters()
 
     def forward(self, graph, feat):
-        r"""Compute set2set pooling.
+        r"""
+        Compute set2set pooling.
 
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            The input graph.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, D)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)` where :math:`N` is the
+            number of nodes in the graph, and :math:`D` means the size of features.
 
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(B, D)`, where
-            :math:`B` refers to the batch size.
+            The output feature with shape :math:`(B, D)`, where :math:`B` refers to
+            the batch size, and :math:`D` means the size of features.
         """
         with graph.local_scope():
             batch_size = graph.batch_size
@@ -497,31 +685,35 @@ class PMALayer(nn.Module):
 
 class SetTransformerEncoder(nn.Module):
-    r"""The Encoder module in `Set Transformer: A Framework for Attention-based
+    r"""
+    Description
+    -----------
+    The Encoder module in `Set Transformer: A Framework for Attention-based
     Permutation-Invariant Neural Networks <https://arxiv.org/pdf/1810.00825.pdf>`__.
 
     Parameters
     ----------
     d_model : int
-        Hidden size of the model.
+        The hidden size of the model.
     n_heads : int
-        Number of heads.
+        The number of heads.
     d_head : int
-        Hidden size of each head.
+        The hidden size of each head.
     d_ff : int
-        Kernel size in FFN (Positionwise Feed-Forward Network) layer.
+        The kernel size in FFN (Positionwise Feed-Forward Network) layer.
     n_layers : int
-        Number of layers.
+        The number of layers.
     block_type : str
         Building block type: 'sab' (Set Attention Block) or 'isab' (Induced
         Set Attention Block).
     m : int or None
-        Number of induced vectors in ISAB Block, set to None if block type
-        is 'sab'.
+        The number of induced vectors in ISAB Block. Set to None if block type
+        is 'sab'.
     dropouth : float
-        Dropout rate of each sublayer.
+        The dropout rate of each sublayer.
     dropouta : float
-        Dropout rate of attention heads.
+        The dropout rate of attention heads.
     """
 
     def __init__(self, d_model, n_heads, d_head, d_ff,
                  n_layers=1, block_type='sab', m=None, dropouth=0., dropouta=0.):
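For completeness, a minimal usage sketch of the encoder in the same doctest style; the hyperparameter values are arbitrary, and the node-level output shape is an assumption based on the forward docstring below:

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.glob import SetTransformerEncoder
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(4)
>>> feat = th.rand(4, 16)
>>> enc = SetTransformerEncoder(d_model=16, n_heads=4, d_head=8, d_ff=32)
>>> enc(g, feat).shape            # encoder output stays node-level, width d_model
torch.Size([4, 16])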
@@ -554,10 +746,10 @@ class SetTransformerEncoder(nn.Module):
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            The input graph.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, D)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)`, where :math:`N` is the
+            number of nodes in the graph.
 
         Returns
         -------
@@ -571,7 +763,11 @@ class SetTransformerEncoder(nn.Module):
 
 class SetTransformerDecoder(nn.Module):
-    r"""The Decoder module in `Set Transformer: A Framework for Attention-based
+    r"""
+    Description
+    -----------
+    The Decoder module in `Set Transformer: A Framework for Attention-based
     Permutation-Invariant Neural Networks <https://arxiv.org/pdf/1810.00825.pdf>`__.
 
     Parameters
@@ -579,15 +775,15 @@ class SetTransformerDecoder(nn.Module):
     d_model : int
         Hidden size of the model.
     num_heads : int
-        Number of heads.
+        The number of heads.
     d_head : int
         Hidden size of each head.
     d_ff : int
         Kernel size in FFN (Positionwise Feed-Forward Network) layer.
     n_layers : int
-        Number of layers.
+        The number of layers.
     k : int
-        Number of seed vectors in PMA (Pooling by Multihead Attention) layer.
+        The number of seed vectors in PMA (Pooling by Multihead Attention) layer.
     dropouth : float
         Dropout rate of each sublayer.
     dropouta : float
@@ -615,16 +811,16 @@ class SetTransformerDecoder(nn.Module):
         Parameters
         ----------
         graph : DGLGraph
-            The graph.
+            The input graph.
         feat : torch.Tensor
-            The input feature with shape :math:`(N, D)` where
-            :math:`N` is the number of nodes in the graph.
+            The input feature with shape :math:`(N, D)`, where :math:`N` is the
+            number of nodes in the graph, and :math:`D` means the size of features.
 
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(B, D)`, where
-            :math:`B` refers to the batch size.
+            The output feature with shape :math:`(B, D)`, where :math:`B` refers to
+            the batch size.
         """
         len_pma = graph.batch_num_nodes()
         len_sab = [self.k] * graph.batch_size
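A matching usage sketch for the decoder; the hyperparameters are arbitrary, and the output width k * d_model is an assumption based on the PMA layer producing k seed vectors per graph:

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.glob import SetTransformerDecoder
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(4)
>>> feat = th.rand(4, 16)
>>> dec = SetTransformerDecoder(d_model=16, num_heads=4, d_head=8, d_ff=32,
...                             n_layers=1, k=2)
>>> dec(g, feat).shape
torch.Size([1, 32])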
...
@@ -104,20 +104,25 @@ class Identity(nn.Module):
         return x
 
 class Sequential(nn.Sequential):
-    r"""A squential container for stacking graph neural network modules.
-    We support two modes: sequentially apply GNN modules on the same graph or
-    a list of given graphs. In the second case, the number of graphs equals the
+    r"""
+    Description
+    -----------
+    A sequential container for stacking graph neural network modules.
+
+    DGL supports two modes: sequentially apply GNN modules on 1) the same graph or
+    2) a list of given graphs. In the second case, the number of graphs equals the
     number of modules inside this container.
 
     Parameters
     ----------
     *args :
-        Sub-modules of type torch.nn.Module, will be added to the container in
-        the order they are passed in the constructor.
+        Sub-modules of torch.nn.Module that will be added to the container in
+        the order in which they are passed in the constructor.
 
     Examples
     --------
+    The following example uses PyTorch backend.
+
     Mode 1: sequentially apply GNN modules on the same graph
@@ -146,16 +151,17 @@ class Sequential(nn.Sequential):
     >>> e_feat = torch.rand(9, 4)
     >>> net(g, n_feat, e_feat)
     (tensor([[39.8597, 45.4542, 25.1877, 30.8086],
              [40.7095, 45.3985, 25.4590, 30.0134],
-             [40.7894, 45.2556, 25.5221, 30.4220]]), tensor([[80.3772, 89.7752, 50.7762, 60.5520],
+             [40.7894, 45.2556, 25.5221, 30.4220]]),
+     tensor([[80.3772, 89.7752, 50.7762, 60.5520],
              [80.5671, 89.3736, 50.6558, 60.6418],
              [80.4620, 89.5142, 50.3643, 60.3126],
              [80.4817, 89.8549, 50.9430, 59.9108],
              [80.2284, 89.6954, 50.0448, 60.1139],
              [79.7846, 89.6882, 50.5097, 60.6213],
              [80.2654, 90.2330, 50.2787, 60.6937],
              [80.3468, 90.0341, 50.2062, 60.2659],
              [80.0556, 90.2789, 50.2882, 60.5845]]))
 
     Mode 2: sequentially apply GNN modules on different graphs
 
@@ -186,11 +192,14 @@ class Sequential(nn.Sequential):
             [220.4007, 239.7365, 213.8648, 234.9637],
             [196.4630, 207.6319, 184.2927, 208.7465]])
     """
 
     def __init__(self, *args):
         super(Sequential, self).__init__(*args)
 
     def forward(self, graph, *feats):
-        r"""Sequentially apply modules to the input.
+        r"""
+        Sequentially apply modules to the input.
 
         Parameters
         ----------
@@ -199,8 +208,8 @@ class Sequential(nn.Sequential):
         *feats :
             Input features.
-            The output of :math:`i`-th block should match that of the input
-            of :math:`(i+1)`-th block.
+            The output of the :math:`i`-th module should match the input
+            of the :math:`(i+1)`-th module in the container.
         """
         if isinstance(graph, list):
             for graph_i, module in zip(graph, self):
...
@@ -6,9 +6,11 @@ __all__ = ['edge_softmax']
 
 def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
-    r"""Compute edge softmax.
-    For a node :math:`i`, edge softmax is an operation of computing
+    r"""
+    Description
+    -----------
+    Compute edge softmax. For a node :math:`i`, edge softmax is an operation that computes
 
     .. math::
       a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
@@ -22,15 +24,18 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
     softmax normalized by source nodes (i.e. :math:`ij` are outgoing edges of
     `i` in the formula). The former case corresponds to softmax in GAT and
     Transformer, and the latter case corresponds to softmax in Capsule networks.
+
+    An example of using edge softmax is in
+    `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__, where
+    the attention weights are computed with such an edge softmax operation.
 
     Parameters
     ----------
-    gidx : HeteroGraphIndex
-        The graph to perfor edge softmax on.
+    graph : DGLGraph
+        The graph to perform edge softmax on.
     logits : torch.Tensor
-        The input edge feature
+        The input edge feature.
     eids : torch.Tensor or ALL, optional
-        Edges on which to apply edge softmax. If ALL, apply edge
-        softmax on all edges in the graph. Default: ALL.
+        A tensor of edge indices on which to apply edge softmax. If ALL, apply
+        edge softmax on all edges in the graph. Default: ALL.
     norm_by : str, could be `src` or `dst`
         Normalized by source nodes or destination nodes. Default: `dst`.
@@ -38,62 +43,65 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
 
     Returns
     -------
     Tensor
-        Softmax value
+        Softmax value.
 
     Notes
     -----
     * Input shape: :math:`(E, *, 1)` where * means any number of
       additional dimensions, :math:`E` equals the length of eids.
+      If `eids` is ALL, :math:`E` equals the number of edges in
+      the graph.
     * Return shape: :math:`(E, *, 1)`
 
     Examples
     --------
+    The following example uses PyTorch backend.
+
     >>> from dgl.ops import edge_softmax
     >>> import dgl
     >>> import torch as th
 
-    Create a :code:`DGLGraph` object and initialize its edge features.
+    Create a :code:`DGLGraph` object :code:`g` and initialize its edge features.
 
     >>> g = dgl.DGLGraph()
     >>> g.add_nodes(3)
     >>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
     >>> edata = th.ones(6, 1).float()
     >>> edata
     tensor([[1.],
             [1.],
             [1.],
             [1.],
             [1.],
             [1.]])
 
     Apply edge softmax on g:
 
     >>> edge_softmax(g, edata)
     tensor([[1.0000],
             [0.5000],
             [0.3333],
             [0.5000],
             [0.3333],
             [0.3333]])
 
     Apply edge softmax on g normalized by source nodes:
 
     >>> edge_softmax(g, edata, norm_by='src')
     tensor([[0.3333],
             [0.3333],
             [0.3333],
             [0.5000],
             [0.5000],
             [1.0000]])
 
     Apply edge softmax on first 4 edges of g:
 
     >>> edge_softmax(g, edata[:4], th.Tensor([0,1,2,3]))
     tensor([[1.0000],
             [0.5000],
             [1.0000],
             [0.5000]])
     """
     return edge_softmax_internal(graph._graph, logits,
                                  eids=eids, norm_by=norm_by)