"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "5d2d14538a87e45891609d172d7aa05a1b756068"
Unverified Commit 1b152bf5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Doc & Tutorial] Fix dgl.batch and GAT (#1418)

* Update doc and tutorials

* Fix
parent 24dc71fc
...@@ -4009,6 +4009,7 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -4009,6 +4009,7 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
The nodes and edges are re-indexed with a new id in the batched graph with the The nodes and edges are re-indexed with a new id in the batched graph with the
rule below: rule below:
====== ========== ======================== === ========================== ====== ========== ======================== === ==========================
item Graph 1 Graph 2 ... Graph k item Graph 1 Graph 2 ... Graph k
====== ========== ======================== === ========================== ====== ========== ======================== === ==========================
...@@ -4040,6 +4041,7 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -4040,6 +4041,7 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
-------- --------
Create two :class:`~dgl.DGLGraph` objects. Create two :class:`~dgl.DGLGraph` objects.
**Instantiation:** **Instantiation:**
>>> import dgl >>> import dgl
>>> import torch as th >>> import torch as th
>>> g1 = dgl.DGLGraph() >>> g1 = dgl.DGLGraph()
...@@ -4052,13 +4054,17 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -4052,13 +4054,17 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
>>> g2.add_edges([0, 2], [1, 1]) # Add edges 0 -> 1, 2 -> 1 >>> g2.add_edges([0, 2], [1, 1]) # Add edges 0 -> 1, 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features >>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]]) # Initialize edge features >>> g2.edata['he'] = th.tensor([[1.], [2.]]) # Initialize edge features
Merge two :class:`~dgl.DGLGraph` objects into one :class:`DGLGraph` object. Merge two :class:`~dgl.DGLGraph` objects into one :class:`DGLGraph` object.
When merging a list of graphs, we can choose to include only a subset of the attributes. When merging a list of graphs, we can choose to include only a subset of the attributes.
>>> bg = dgl.batch([g1, g2], edge_attrs=None) >>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata >>> bg.edata
{} {}
Below one can see that the nodes are re-indexed. The edges are re-indexed in Below one can see that the nodes are re-indexed. The edges are re-indexed in
the same way. the same way.
>>> bg.nodes() >>> bg.nodes()
tensor([0, 1, 2, 3, 4]) tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv'] >>> bg.ndata['hv']
...@@ -4067,14 +4073,17 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -4067,14 +4073,17 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
[2.], [2.],
[3.], [3.],
[4.]]) [4.]])
**Property:** **Property:**
We can still get a brief summary of the graphs that constitute the batched graph. We can still get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size >>> bg.batch_size
2 2
>>> bg.batch_num_nodes >>> bg.batch_num_nodes
[2, 3] [2, 3]
>>> bg.batch_num_edges >>> bg.batch_num_edges
[1, 2] [1, 2]
**Readout:** **Readout:**
Another common demand for graph neural networks is graph readout, which is a Another common demand for graph neural networks is graph readout, which is a
function that takes in the node attributes and/or edge attributes for a graph function that takes in the node attributes and/or edge attributes for a graph
...@@ -4082,20 +4091,26 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -4082,20 +4091,26 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
DGL also supports performing readout for a batch of graphs at once. DGL also supports performing readout for a batch of graphs at once.
Below we take the built-in readout function :func:`sum_nodes` as an example, which Below we take the built-in readout function :func:`sum_nodes` as an example, which
sums over a particular kind of node attribute for each graph. sums over a particular kind of node attribute for each graph.
>>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph. >>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
tensor([[1.], # 0 + 1 tensor([[1.], # 0 + 1
[9.]]) # 2 + 3 + 4 [9.]]) # 2 + 3 + 4
**Message passing:** **Message passing:**
For message passing and related operations, batched :class:`DGLGraph` acts exactly For message passing and related operations, batched :class:`DGLGraph` acts exactly
the same as a single :class:`~dgl.DGLGraph` with batch size 1. the same as a single :class:`~dgl.DGLGraph` with batch size 1.
**Update Attributes:** **Update Attributes:**
Updating the attributes of the batched graph has no effect on the original graphs. Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.edata['he'] = th.zeros(3, 2) >>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he'] >>> g2.edata['he']
tensor([[1.], tensor([[1.],
[2.]])} [2.]])}
Instead, we can decompose the batched graph back into a list of graphs and use them Instead, we can decompose the batched graph back into a list of graphs and use them
to replace the original graphs. to replace the original graphs.
>>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraph objects >>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraph objects
>>> g2.edata['he'] >>> g2.edata['he']
tensor([[0., 0.], tensor([[0., 0.],
......
...@@ -120,6 +120,13 @@ class GATLayer(nn.Module): ...@@ -120,6 +120,13 @@ class GATLayer(nn.Module):
self.fc = nn.Linear(in_dim, out_dim, bias=False) self.fc = nn.Linear(in_dim, out_dim, bias=False)
# equation (2) # equation (2)
self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False) self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
def edge_attention(self, edges): def edge_attention(self, edges):
# edge UDF for equation (2) # edge UDF for equation (2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment