Unverified commit 80b22ad8, authored by Hongzhi (Steve), Chen and committed by GitHub
Browse files

[Readability] Polish udf.py (#4534)



* Polish udf.py

* Update udf.py
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 7dba1991
"""User-defined function related data structures.""" """User-defined function related data structures."""
from __future__ import absolute_import from __future__ import absolute_import
class EdgeBatch(object): class EdgeBatch(object):
"""The class that can represent a batch of edges. """The class that can represent a batch of edges.
...@@ -19,6 +20,7 @@ class EdgeBatch(object): ...@@ -19,6 +20,7 @@ class EdgeBatch(object):
dst_data : dict[str, Tensor] dst_data : dict[str, Tensor]
Dst node features. Dst node features.
""" """
def __init__(self, graph, eid, etype, src_data, edge_data, dst_data): def __init__(self, graph, eid, etype, src_data, edge_data, dst_data):
self._graph = graph self._graph = graph
self._eid = eid self._eid = eid
...@@ -38,24 +40,25 @@ class EdgeBatch(object): ...@@ -38,24 +40,25 @@ class EdgeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph and set a node feature 'h' >>> # Instantiate a graph and set a node feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1) >>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that retrieves the source node features for edges >>> # Define a UDF that retrieves the source node features for edges.
>>> def edge_udf(edges): >>> def edge_udf(edges):
>>> # edges.src['h'] is a tensor of shape (E, 1), >>> # edges.src['h'] is a tensor of shape (E, 1),
>>> # where E is the number of edges in the batch. >>> # where E is the number of edges in the batch.
>>> return {'src': edges.src['h']} >>> return {'src': edges.src['h']}
>>> # Copy features from source nodes to edges >>> # Copy features from source nodes to edges.
>>> g.apply_edges(edge_udf) >>> g.apply_edges(edge_udf)
>>> g.edata['src'] >>> g.edata['src']
tensor([[1.], tensor([[1.],
[1.], [1.],
[1.]]) [1.]])
>>> # Use edge UDF in message passing, which is equivalent to dgl.function.copy_u >>> # Use edge UDF in message passing, which is equivalent to
>>> # dgl.function.copy_u.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('src', 'h')) >>> g.update_all(edge_udf, fn.sum('src', 'h'))
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -75,24 +78,25 @@ class EdgeBatch(object): ...@@ -75,24 +78,25 @@ class EdgeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph and set a node feature 'h' >>> # Instantiate a graph and set a node feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.tensor([[0.], [1.]]) >>> g.ndata['h'] = torch.tensor([[0.], [1.]])
>>> # Define a UDF that retrieves the destination node features for edges >>> # Define a UDF that retrieves the destination node features for
>>> # edges.
>>> def edge_udf(edges): >>> def edge_udf(edges):
>>> # edges.dst['h'] is a tensor of shape (E, 1), >>> # edges.dst['h'] is a tensor of shape (E, 1),
>>> # where E is the number of edges in the batch. >>> # where E is the number of edges in the batch.
>>> return {'dst': edges.dst['h']} >>> return {'dst': edges.dst['h']}
>>> # Copy features from destination nodes to edges >>> # Copy features from destination nodes to edges.
>>> g.apply_edges(edge_udf) >>> g.apply_edges(edge_udf)
>>> g.edata['dst'] >>> g.edata['dst']
tensor([[1.], tensor([[1.],
[1.], [1.],
[1.]]) [1.]])
>>> # Use edge UDF in message passing >>> # Use edge UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('dst', 'h')) >>> g.update_all(edge_udf, fn.sum('dst', 'h'))
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -112,24 +116,25 @@ class EdgeBatch(object): ...@@ -112,24 +116,25 @@ class EdgeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph and set an edge feature 'h' >>> # Instantiate a graph and set an edge feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.edata['h'] = torch.tensor([[1.], [1.], [1.]]) >>> g.edata['h'] = torch.tensor([[1.], [1.], [1.]])
>>> # Define a UDF that retrieves the feature 'h' for all edges >>> # Define a UDF that retrieves the feature 'h' for all edges.
>>> def edge_udf(edges): >>> def edge_udf(edges):
>>> # edges.data['h'] is a tensor of shape (E, 1), >>> # edges.data['h'] is a tensor of shape (E, 1),
>>> # where E is the number of edges in the batch. >>> # where E is the number of edges in the batch.
>>> return {'data': edges.data['h']} >>> return {'data': edges.data['h']}
>>> # Make a copy of the feature with name 'data' >>> # Make a copy of the feature with name 'data'.
>>> g.apply_edges(edge_udf) >>> g.apply_edges(edge_udf)
>>> g.edata['data'] >>> g.edata['data']
tensor([[1.], tensor([[1.],
[1.], [1.],
[1.]]) [1.]])
>>> # Use edge UDF in message passing, which is equivalent to dgl.function.copy_e >>> # Use edge UDF in message passing, which is equivalent to
>>> # dgl.function.copy_e.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('data', 'h')) >>> g.update_all(edge_udf, fn.sum('data', 'h'))
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -139,13 +144,13 @@ class EdgeBatch(object): ...@@ -139,13 +144,13 @@ class EdgeBatch(object):
return self._edge_data return self._edge_data
def edges(self): def edges(self):
"""Return the edges in the batch """Return the edges in the batch.
Returns Returns
------- -------
(U, V, EID) : (Tensor, Tensor, Tensor) (U, V, EID) : (Tensor, Tensor, Tensor)
The edges in the batch. For each :math:`i`, :math:`(U[i], V[i])` is an edge The edges in the batch. For each :math:`i`, :math:`(U[i], V[i])` is
from :math:`U[i]` to :math:`V[i]` with ID :math:`EID[i]`. an edge from :math:`U[i]` to :math:`V[i]` with ID :math:`EID[i]`.
Examples Examples
-------- --------
...@@ -154,22 +159,23 @@ class EdgeBatch(object): ...@@ -154,22 +159,23 @@ class EdgeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph >>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> # Define a UDF that retrieves and concatenates the end nodes of the edges >>> # Define a UDF that retrieves and concatenates the end nodes of the
>>> # edges.
>>> def edge_udf(edges): >>> def edge_udf(edges):
>>> src, dst, _ = edges.edges() >>> src, dst, _ = edges.edges()
>>> return {'uv': torch.stack([src, dst], dim=1).float()} >>> return {'uv': torch.stack([src, dst], dim=1).float()}
>>> # Create a feature 'uv' with the end nodes of the edges >>> # Create a feature 'uv' with the end nodes of the edges.
>>> g.apply_edges(edge_udf) >>> g.apply_edges(edge_udf)
>>> g.edata['uv'] >>> g.edata['uv']
tensor([[0., 1.], tensor([[0., 1.],
[1., 1.], [1., 1.],
[1., 0.]]) [1., 0.]])
>>> # Use edge UDF in message passing >>> # Use edge UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('uv', 'h')) >>> g.update_all(edge_udf, fn.sum('uv', 'h'))
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -193,21 +199,21 @@ class EdgeBatch(object): ...@@ -193,21 +199,21 @@ class EdgeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph >>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> # Define a UDF that returns one for each edge >>> # Define a UDF that returns one for each edge.
>>> def edge_udf(edges): >>> def edge_udf(edges):
>>> return {'h': torch.ones(edges.batch_size(), 1)} >>> return {'h': torch.ones(edges.batch_size(), 1)}
>>> # Creates a feature 'h' >>> # Creates a feature 'h'.
>>> g.apply_edges(edge_udf) >>> g.apply_edges(edge_udf)
>>> g.edata['h'] >>> g.edata['h']
tensor([[1.], tensor([[1.],
[1.], [1.],
[1.]]) [1.]])
>>> # Use edge UDF in message passing >>> # Use edge UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('h', 'h')) >>> g.update_all(edge_udf, fn.sum('h', 'h'))
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -231,6 +237,7 @@ class EdgeBatch(object): ...@@ -231,6 +237,7 @@ class EdgeBatch(object):
destination node type) for this edge batch.""" destination node type) for this edge batch."""
return self._etype return self._etype
class NodeBatch(object): class NodeBatch(object):
"""The class to represent a batch of nodes. """The class to represent a batch of nodes.
...@@ -247,6 +254,7 @@ class NodeBatch(object): ...@@ -247,6 +254,7 @@ class NodeBatch(object):
msgs : dict[str, Tensor], optional msgs : dict[str, Tensor], optional
Messages data. Messages data.
""" """
def __init__(self, graph, nodes, ntype, data, msgs=None): def __init__(self, graph, nodes, ntype, data, msgs=None):
self._graph = graph self._graph = graph
self._nodes = nodes self._nodes = nodes
...@@ -265,20 +273,20 @@ class NodeBatch(object): ...@@ -265,20 +273,20 @@ class NodeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph and set a feature 'h' >>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1) >>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received and the original feature >>> # Define a UDF that computes the sum of the messages received and
>>> # for each node >>> # the original feature for each node.
>>> def node_udf(nodes): >>> def node_udf(nodes):
>>> # nodes.data['h'] is a tensor of shape (N, 1), >>> # nodes.data['h'] is a tensor of shape (N, 1),
>>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1), >>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
>>> # where N is the number of nodes in the batch, >>> # where N is the number of nodes in the batch, D is the number
>>> # D is the number of messages received per node for this node batch >>> # of messages received per node for this node batch.
>>> return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)} >>> return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing >>> # Use node UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf) >>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -298,20 +306,20 @@ class NodeBatch(object): ...@@ -298,20 +306,20 @@ class NodeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph and set a feature 'h' >>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1) >>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received and the original feature >>> # Define a UDF that computes the sum of the messages received and
>>> # for each node >>> # the original feature for each node.
>>> def node_udf(nodes): >>> def node_udf(nodes):
>>> # nodes.data['h'] is a tensor of shape (N, 1), >>> # nodes.data['h'] is a tensor of shape (N, 1),
>>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1), >>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
>>> # where N is the number of nodes in the batch, >>> # where N is the number of nodes in the batch, D is the number
>>> # D is the number of messages received per node for this node batch >>> # of messages received per node for this node batch.
>>> return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)} >>> return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing >>> # Use node UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf) >>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -336,20 +344,21 @@ class NodeBatch(object): ...@@ -336,20 +344,21 @@ class NodeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph and set a feature 'h' >>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1) >>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received and the original ID >>> # Define a UDF that computes the sum of the messages received and
>>> # for each node >>> # the original ID for each node.
>>> def node_udf(nodes): >>> def node_udf(nodes):
>>> # nodes.nodes() is a tensor of shape (N), >>> # nodes.nodes() is a tensor of shape (N),
>>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1), >>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
>>> # where N is the number of nodes in the batch, >>> # where N is the number of nodes in the batch, D is the number
>>> # D is the number of messages received per node for this node batch >>> # of messages received per node for this node batch.
>>> return {'h': nodes.nodes().unsqueeze(-1).float() + nodes.mailbox['m'].sum(1)} >>> return {'h': nodes.nodes().unsqueeze(-1).float()
>>> + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing >>> # Use node UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf) >>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h'] >>> g.ndata['h']
...@@ -372,16 +381,17 @@ class NodeBatch(object): ...@@ -372,16 +381,17 @@ class NodeBatch(object):
>>> import dgl >>> import dgl
>>> import torch >>> import torch
>>> # Instantiate a graph >>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0]))) >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1) >>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received for each node >>> # Define a UDF that computes the sum of the messages received for
>>> # and increments the result by 1 >>> # each node and increments the result by 1.
>>> def node_udf(nodes): >>> def node_udf(nodes):
>>> return {'h': torch.ones(nodes.batch_size(), 1) + nodes.mailbox['m'].sum(1)} >>> return {'h': torch.ones(nodes.batch_size(), 1)
>>> + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing >>> # Use node UDF in message passing.
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf) >>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h'] >>> g.ndata['h']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment