Unverified Commit 988c8b20 authored by Zihao Ye, committed by GitHub

[DOC] Improve the docstring of GATConv and new APIs (#1339)

* upd

* upd

* upd

* upd

* upd
parent e6584043
"""Base graph class specialized for neural networks on graphs."""
# pylint: disable=too-many-lines
from __future__ import absolute_import
from collections import defaultdict
......@@ -1294,6 +1295,12 @@ class DGLGraph(DGLBaseGraph):
-------
Tensor
The parent node id array.
Notes
-----
The parent node id information is stored in the ``_ID`` field of the
graph's node frame, so please do not manually modify this field.
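Examples
--------
A minimal sketch (assuming the PyTorch backend):
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> sg = g.subgraph([0, 2, 4])
>>> sg.parent_nid
tensor([0, 2, 4])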
"""
if self._parent is None:
raise DGLError("We only support parent_nid for subgraphs.")
......@@ -1310,6 +1317,12 @@ class DGLGraph(DGLBaseGraph):
-------
Tensor
The parent edge id array.
Notes
-----
The parent edge id information is stored in the ``_ID`` field of the
graph's edge frame, so please do not manually modify this field.
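Examples
--------
A minimal sketch (assuming the PyTorch backend):
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2], [1, 2, 0])
>>> sg = g.edge_subgraph([0, 2])
>>> sg.parent_eid
tensor([0, 2])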
"""
if self._parent is None:
raise DGLError("We only support parent_eid for subgraphs.")
......@@ -1322,6 +1335,52 @@ class DGLGraph(DGLBaseGraph):
----------
inplace : bool
If True, use inplace write (no gradient, but faster)
Examples
--------
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5) # Create a DGLGraph with 5 nodes
>>> g.add_edges([0,1,2,3,4], [1,2,3,4,0])
>>> subg = g.subgraph([0,1,3,4]) # Take the subgraph induced by nodes 0,1,3,4
>>> subg.ndata['h'] = th.rand(4, 3)
>>> subg.edata['h'] = th.rand(3, 3)
>>> subg.ndata
{'_ID': tensor([0, 1, 3, 4]), 'h': tensor([[0.3803, 0.9351, 0.0611],
[0.6492, 0.4327, 0.3610],
[0.7471, 0.4257, 0.4130],
[0.9766, 0.6280, 0.6075]])}
>>> subg.edata
{'_ID': tensor([0, 3, 4]), 'h': tensor([[0.8192, 0.2409, 0.6278],
[0.9600, 0.3501, 0.8037],
[0.6521, 0.9029, 0.4901]])}
>>> g
DGLGraph(num_nodes=5, num_edges=5,
ndata_schemes={}
edata_schemes={})
>>> subg.copy_to_parent()
>>> g.ndata
{'h': tensor([[0.3803, 0.9351, 0.0611],
[0.6492, 0.4327, 0.3610],
[0.0000, 0.0000, 0.0000],
[0.7471, 0.4257, 0.4130],
[0.9766, 0.6280, 0.6075]])}
>>> g.edata
{'h': tensor([[0.8192, 0.2409, 0.6278],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.9600, 0.3501, 0.8037],
[0.6521, 0.9029, 0.4901]])}
Notes
-----
This API excludes the ``_ID`` field in both the node frame and the
edge frame. That is to say, if a user takes a subgraph ``sg`` of a
graph ``g`` and applies :func:`~dgl.DGLGraph.copy_to_parent` on ``sg``,
it would not pollute the ``_ID`` fields of the node/edge frames of ``g``.
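For instance, continuing the example above, the parent's node frame contains
no ``_ID`` field after the copy:
>>> list(g.ndata.keys())
['h']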
See Also
--------
subgraph
edge_subgraph
parent_nid
parent_eid
copy_from_parent
map_to_subgraph_nid
"""
if self._parent is None:
raise DGLError("We only support copy_to_parent for subgraphs.")
......@@ -1339,6 +1398,64 @@ class DGLGraph(DGLBaseGraph):
"""Copy node/edge features from the parent graph.
All old features will be removed.
Examples
--------
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5) # Create a DGLGraph with 5 nodes
>>> g.add_edges([0,1,2,3,4], [1,2,3,4,0])
>>> g.ndata['h'] = th.rand(5, 3)
>>> g.ndata['h']
tensor([[0.3749, 0.5681, 0.4749],
[0.6312, 0.7955, 0.3682],
[0.0215, 0.0303, 0.0282],
[0.8840, 0.6842, 0.3645],
[0.9253, 0.8427, 0.6626]])
>>> g.edata['h'] = th.rand(5, 3)
>>> g.edata['h']
tensor([[0.0659, 0.8552, 0.9208],
[0.8238, 0.0332, 0.7864],
[0.1629, 0.4149, 0.1363],
[0.0648, 0.6582, 0.4400],
[0.4321, 0.1612, 0.7893]])
>>> g
DGLGraph(num_nodes=5, num_edges=5,
ndata_schemes={'h': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={'h': Scheme(shape=(3,), dtype=torch.float32)})
>>> subg = g.subgraph([0,1,3,4]) # Take the subgraph induced by nodes 0,1,3,4
>>> subg # '_ID' field records node/edge mapping
DGLGraph(num_nodes=4, num_edges=3,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> subg.copy_from_parent()
>>> subg.ndata
{'h': tensor([[0.3749, 0.5681, 0.4749],
[0.6312, 0.7955, 0.3682],
[0.8840, 0.6842, 0.3645],
[0.9253, 0.8427, 0.6626]]), '_ID': tensor([0, 1, 3, 4])}
>>> subg.edata
{'h': tensor([[0.0659, 0.8552, 0.9208],
[0.0648, 0.6582, 0.4400],
[0.4321, 0.1612, 0.7893]]), '_ID': tensor([0, 3, 4])}
Notes
-----
This API excludes the ``_ID`` field in both the node frame and the
edge frame. That is to say, if a user takes a subgraph ``sg1`` of a
subgraph ``sg`` whose ``_ID`` fields in the node/edge frames are not
None, and applies :func:`~dgl.DGLGraph.copy_from_parent` on ``sg1``,
it would not pollute the ``_ID`` fields of the node/edge frames of
``sg1``.
See Also
--------
subgraph
edge_subgraph
parent_nid
parent_eid
copy_to_parent
map_to_subgraph_nid
"""
if self._parent is None:
raise DGLError("We only support copy_from_parent for subgraphs.")
......@@ -1365,6 +1482,24 @@ class DGLGraph(DGLBaseGraph):
-------
tensor
The node ID array in the subgraph.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> sg = g.subgraph([0,2,4])
>>> sg.map_to_subgraph_nid([2,4])
tensor([1, 2])
See Also
--------
subgraph
edge_subgraph
parent_nid
parent_eid
copy_to_parent
copy_from_parent
"""
if self._parent is None:
raise DGLError("We only support map_to_subgraph_nid for subgraphs.")
......@@ -1376,6 +1511,55 @@ class DGLGraph(DGLBaseGraph):
"""Remove all batching information of the graph, and regard the current
graph as an independent graph rather then a batched graph.
Graph topology and attributes would not be affected.
User can change the structure of the flattened graph.
Examples
--------
>>> import dgl
>>> import torch as th
>>> g_list = []
>>> for _ in range(3):  # Create three graphs, each with 4 nodes
...     g = dgl.DGLGraph()
...     g.add_nodes(4)
...     g.add_edges([0,1,2,3], [1,2,3,0])
...     g.ndata['h'] = th.rand(4, 3)
...     g_list.append(g)
>>> bg = dgl.batch(g_list)
>>> bg.ndata
{'h': tensor([[0.0463, 0.1251, 0.5967],
[0.8633, 0.9812, 0.8601],
[0.7828, 0.3624, 0.7845],
[0.2169, 0.8761, 0.3237],
[0.1752, 0.1478, 0.5611],
[0.5279, 0.2556, 0.2304],
[0.8950, 0.8203, 0.5604],
[0.2999, 0.2946, 0.2676],
[0.3419, 0.2935, 0.6618],
[0.8137, 0.8927, 0.8953],
[0.6229, 0.7153, 0.5041],
[0.5659, 0.0612, 0.2351]])}
>>> bg.batch_size
3
>>> bg.batch_num_nodes
[4, 4, 4]
>>> bg.batch_num_edges
[4, 4, 4]
>>> bg.flatten()
>>> bg.batch_size
1
>>> bg.batch_num_nodes
[12]
>>> bg.batch_num_edges
[12]
>>> bg.remove_nodes([1,3,5,7,9,11])
>>> bg.ndata
{'h': tensor([[0.0463, 0.1251, 0.5967],
[0.7828, 0.3624, 0.7845],
[0.1752, 0.1478, 0.5611],
[0.8950, 0.8203, 0.5604],
[0.3419, 0.2935, 0.6618],
[0.6229, 0.7153, 0.5041]])}
"""
self._batch_num_nodes = None
self._batch_num_edges = None
......@@ -1384,6 +1568,41 @@ class DGLGraph(DGLBaseGraph):
"""Detach the current graph from its parent, and regard the current graph
as an independent graph rather then a subgraph.
Graph topology and attributes would not be affected.
User can change the structure of the detached graph.
Examples
--------
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph() # Graph 1
>>> g.add_nodes(5)
>>> g.ndata['h'] = th.rand(5, 3)
>>> g.ndata
{'h': tensor([[0.9595, 0.7450, 0.5495],
[0.8253, 0.2902, 0.4393],
[0.3783, 0.4548, 0.6075],
[0.2323, 0.0936, 0.6580],
[0.1624, 0.3484, 0.3750]])}
>>> subg = g.subgraph([0,1,3]) # Create a subgraph
>>> subg.parent # Get the parent reference of subg
DGLGraph(num_nodes=5, num_edges=0,
ndata_schemes={'h': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={})
>>> subg.copy_from_parent()
>>> subg.detach_parent() # Detach the subgraph from its parent
>>> subg.parent is None
True
>>> subg.add_nodes(1) # Change the structure of the subgraph
>>> subg
DGLGraph(num_nodes=4, num_edges=0,
ndata_schemes={'h': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={})
>>> subg.ndata
{'h': tensor([[0.9595, 0.7450, 0.5495],
[0.8253, 0.2902, 0.4393],
[0.2323, 0.0936, 0.6580],
[0.0000, 0.0000, 0.0000]])}
"""
self._parent = None
self.ndata.pop(NID)
......@@ -3146,6 +3365,11 @@ class DGLGraph(DGLBaseGraph):
--------
subgraphs
edge_subgraph
parent_nid
parent_eid
copy_from_parent
copy_to_parent
map_to_subgraph_nid
"""
induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
......@@ -3172,6 +3396,11 @@ class DGLGraph(DGLBaseGraph):
See Also
--------
subgraph
parent_nid
parent_eid
copy_from_parent
copy_to_parent
map_to_subgraph_nid
"""
induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes)
......@@ -3231,6 +3460,9 @@ class DGLGraph(DGLBaseGraph):
See Also
--------
subgraph
copy_from_parent
copy_to_parent
map_to_subgraph_nid
"""
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
......@@ -3696,8 +3928,22 @@ class DGLGraph(DGLBaseGraph):
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
"""Batch a collection of :class:`~dgl.DGLGraph` and return a batched
:class:`DGLGraph` object that is independent of the :attr:`graph_list`, so that
one can perform message passing and readout over a batch of graphs
simultaneously. The batch size of the returned graph is the length of
:attr:`graph_list`.
The nodes and edges are re-indexed in the batched graph according to
the following rule:
====== ========== ======================== === ==========================
item Graph 1 Graph 2 ... Graph k
====== ========== ======================== === ==========================
raw id 0, ..., N1 0, ..., N2 ... ..., Nk
new id 0, ..., N1 N1 + 1, ..., N1 + N2 + 1 ... ..., N1 + ... + Nk + k - 1
====== ========== ======================== === ==========================
Modifying the features of the batched graph has no effect on the original
graphs. See the examples below for how to work around this.
Parameters
----------
......@@ -3716,6 +3962,71 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
DGLGraph
One single batched graph.
Examples
--------
Create two :class:`~dgl.DGLGraph` objects.
**Instantiation:**
>>> import dgl
>>> import torch as th
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(2) # Add 2 nodes
>>> g1.add_edge(0, 1) # Add edge 0 -> 1
>>> g1.ndata['hv'] = th.tensor([[0.], [1.]]) # Initialize node features
>>> g1.edata['he'] = th.tensor([[0.]]) # Initialize edge features
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3) # Add 3 nodes
>>> g2.add_edges([0, 2], [1, 1]) # Add edges 0 -> 1, 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]]) # Initialize edge features
Merge two :class:`~dgl.DGLGraph` objects into one :class:`DGLGraph` object.
When merging a list of graphs, we can choose to include only a subset of the attributes.
>>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata
{}
Below one can see that the nodes are re-indexed. The edges are re-indexed in
the same way.
>>> bg.nodes()
tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv']
tensor([[0.],
[1.],
[2.],
[3.],
[4.]])
**Property:**
We can still get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size
2
>>> bg.batch_num_nodes
[2, 3]
>>> bg.batch_num_edges
[1, 2]
**Readout:**
Another common demand for graph neural networks is graph readout, which is a
function that takes in the node attributes and/or edge attributes for a graph
and outputs a vector summarizing the information in the graph.
DGL also supports performing readout for a batch of graphs at once.
Below we take the built-in readout function :func:`sum_nodes` as an example, which
sums over a particular kind of node attribute for each graph.
>>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
tensor([[1.], # 0 + 1
[9.]]) # 2 + 3 + 4
**Message passing:**
For message passing and related operations, batched :class:`DGLGraph` acts exactly
the same as a single :class:`~dgl.DGLGraph` with batch size 1.
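For instance, one can run one round of message passing over the batched graph
(a minimal sketch, assuming the built-in ``copy_u`` message function and the
``sum`` reducer; nodes without in-edges receive zeros):
>>> import dgl.function as fn
>>> bg.update_all(fn.copy_u('hv', 'm'), fn.sum('m', 'h_sum'))
>>> bg.ndata['h_sum']  # node 3 aggregates 'hv' of nodes 2 and 4
tensor([[0.],
[0.],
[0.],
[6.],
[0.]])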
**Update Attributes:**
Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he']
tensor([[1.],
[2.]])
Instead, we can decompose the batched graph back into a list of graphs and use them
to replace the original graphs.
>>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraph objects
>>> g2.edata['he']
tensor([[0., 0.],
[0., 0.]])
See Also
--------
unbatch
......
......@@ -102,10 +102,20 @@ class GATConv(nn.Block):
graph = graph.local_var()
h = self.feat_drop(feat)
feat = self.fc(h).reshape(-1, self._num_heads, self._out_feats)
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l^T Wh_i + a_r^T Wh_j
# Our implementation is much more efficient because we do not need to
# store [Wh_i || Wh_j] on edges, which is memory-costly. Plus, the
# addition can be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and reduces the memory footprint.
el = (feat * self.attn_l.data(feat.context)).sum(axis=-1).expand_dims(-1)
er = (feat * self.attn_r.data(feat.context)).sum(axis=-1).expand_dims(-1)
graph.ndata.update({'ft': feat, 'el': el, 'er': er})
# compute edge attention, el and er are a_l^T Wh_i and a_r^T Wh_j, respectively
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax
......
......@@ -101,10 +101,20 @@ class GATConv(nn.Module):
graph = graph.local_var()
h = self.feat_drop(feat)
feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l^T Wh_i + a_r^T Wh_j
# Our implementation is much more efficient because we do not need to
# store [Wh_i || Wh_j] on edges, which is memory-costly. Plus, the
# addition can be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and reduces the memory footprint.
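# A hypothetical sanity check of the identity above (a sketch, not
# executed here; wh_i and wh_j stand for projected feature vectors of
# dimension D):
#   D = 8
#   wh_i, wh_j, a = th.randn(D), th.randn(D), th.randn(2 * D)
#   a_l, a_r = a[:D], a[D:]
#   assert th.allclose(a @ th.cat([wh_i, wh_j]), a_l @ wh_i + a_r @ wh_j)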
el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.ndata.update({'ft': feat, 'el': el, 'er': er})
# compute edge attention, el and er are a_l^T Wh_i and a_r^T Wh_j, respectively
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax
......
......@@ -103,10 +103,20 @@ class GATConv(layers.Layer):
graph = graph.local_var()
h = self.feat_drop(feat)
feat = tf.reshape(self.fc(h), (-1, self._num_heads, self._out_feats))
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l^T Wh_i + a_r^T Wh_j
# Our implementation is much more efficient because we do not need to
# store [Wh_i || Wh_j] on edges, which is memory-costly. Plus, the
# addition can be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and reduces the memory footprint.
el = tf.reduce_sum(feat * self.attn_l, axis=-1, keepdims=True)
er = tf.reduce_sum(feat * self.attn_r, axis=-1, keepdims=True)
graph.ndata.update({'ft': feat, 'el': el, 'er': er})
# compute edge attention, el and er are a_l^T Wh_i and a_r^T Wh_j, respectively
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax
......