Unverified commit 80b22ad8, authored by Hongzhi (Steve), Chen and committed by GitHub
Browse files

[Readability] Polish udf.py (#4534)



* Polish udf.py

* Update udf.py
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 7dba1991
"""User-defined function related data structures."""
from __future__ import absolute_import
class EdgeBatch(object):
"""The class that can represent a batch of edges.
......@@ -19,6 +20,7 @@ class EdgeBatch(object):
dst_data : dict[str, Tensor]
Dst node features.
"""
def __init__(self, graph, eid, etype, src_data, edge_data, dst_data):
self._graph = graph
self._eid = eid
......@@ -38,24 +40,25 @@ class EdgeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph and set a node feature 'h'
>>> # Instantiate a graph and set a node feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that retrieves the source node features for edges
>>> # Define a UDF that retrieves the source node features for edges.
>>> def edge_udf(edges):
>>> # edges.src['h'] is a tensor of shape (E, 1),
>>> # where E is the number of edges in the batch.
>>> return {'src': edges.src['h']}
>>> # Copy features from source nodes to edges
>>> # Copy features from source nodes to edges.
>>> g.apply_edges(edge_udf)
>>> g.edata['src']
tensor([[1.],
[1.],
[1.]])
>>> # Use edge UDF in message passing, which is equivalent to dgl.function.copy_u
>>> # Use edge UDF in message passing, which is equivalent to
>>> # dgl.function.copy_u.
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('src', 'h'))
>>> g.ndata['h']
......@@ -75,24 +78,25 @@ class EdgeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph and set a node feature 'h'
>>> # Instantiate a graph and set a node feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.tensor([[0.], [1.]])
>>> # Define a UDF that retrieves the destination node features for edges
>>> # Define a UDF that retrieves the destination node features for
>>> # edges.
>>> def edge_udf(edges):
>>> # edges.dst['h'] is a tensor of shape (E, 1),
>>> # where E is the number of edges in the batch.
>>> return {'dst': edges.dst['h']}
>>> # Copy features from destination nodes to edges
>>> # Copy features from destination nodes to edges.
>>> g.apply_edges(edge_udf)
>>> g.edata['dst']
tensor([[1.],
[1.],
[1.]])
>>> # Use edge UDF in message passing
>>> # Use edge UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('dst', 'h'))
>>> g.ndata['h']
......@@ -112,24 +116,25 @@ class EdgeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph and set an edge feature 'h'
>>> # Instantiate a graph and set an edge feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.edata['h'] = torch.tensor([[1.], [1.], [1.]])
>>> # Define a UDF that retrieves the feature 'h' for all edges
>>> # Define a UDF that retrieves the feature 'h' for all edges.
>>> def edge_udf(edges):
>>> # edges.data['h'] is a tensor of shape (E, 1),
>>> # where E is the number of edges in the batch.
>>> return {'data': edges.data['h']}
>>> # Make a copy of the feature with name 'data'
>>> # Make a copy of the feature with name 'data'.
>>> g.apply_edges(edge_udf)
>>> g.edata['data']
tensor([[1.],
[1.],
[1.]])
>>> # Use edge UDF in message passing, which is equivalent to dgl.function.copy_e
>>> # Use edge UDF in message passing, which is equivalent to
>>> # dgl.function.copy_e.
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('data', 'h'))
>>> g.ndata['h']
......@@ -139,13 +144,13 @@ class EdgeBatch(object):
return self._edge_data
def edges(self):
"""Return the edges in the batch
"""Return the edges in the batch.
Returns
-------
(U, V, EID) : (Tensor, Tensor, Tensor)
The edges in the batch. For each :math:`i`, :math:`(U[i], V[i])` is an edge
from :math:`U[i]` to :math:`V[i]` with ID :math:`EID[i]`.
The edges in the batch. For each :math:`i`, :math:`(U[i], V[i])` is
an edge from :math:`U[i]` to :math:`V[i]` with ID :math:`EID[i]`.
Examples
--------
......@@ -154,22 +159,23 @@ class EdgeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph
>>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> # Define a UDF that retrieves and concatenates the end nodes of the edges
>>> # Define a UDF that retrieves and concatenates the end nodes of the
>>> # edges.
>>> def edge_udf(edges):
>>> src, dst, _ = edges.edges()
>>> return {'uv': torch.stack([src, dst], dim=1).float()}
>>> # Create a feature 'uv' with the end nodes of the edges
>>> # Create a feature 'uv' with the end nodes of the edges.
>>> g.apply_edges(edge_udf)
>>> g.edata['uv']
tensor([[0., 1.],
[1., 1.],
[1., 0.]])
>>> # Use edge UDF in message passing
>>> # Use edge UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('uv', 'h'))
>>> g.ndata['h']
......@@ -193,21 +199,21 @@ class EdgeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph
>>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> # Define a UDF that returns one for each edge
>>> # Define a UDF that returns one for each edge.
>>> def edge_udf(edges):
>>> return {'h': torch.ones(edges.batch_size(), 1)}
>>> # Creates a feature 'h'
>>> # Creates a feature 'h'.
>>> g.apply_edges(edge_udf)
>>> g.edata['h']
tensor([[1.],
[1.],
[1.]])
>>> # Use edge UDF in message passing
>>> # Use edge UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('h', 'h'))
>>> g.ndata['h']
......@@ -231,6 +237,7 @@ class EdgeBatch(object):
destination node type) for this edge batch."""
return self._etype
class NodeBatch(object):
"""The class to represent a batch of nodes.
......@@ -247,6 +254,7 @@ class NodeBatch(object):
msgs : dict[str, Tensor], optional
Messages data.
"""
def __init__(self, graph, nodes, ntype, data, msgs=None):
self._graph = graph
self._nodes = nodes
......@@ -265,20 +273,20 @@ class NodeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph and set a feature 'h'
>>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received and the original feature
>>> # for each node
>>> # Define a UDF that computes the sum of the messages received and
>>> # the original feature for each node.
>>> def node_udf(nodes):
>>> # nodes.data['h'] is a tensor of shape (N, 1),
>>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
>>> # where N is the number of nodes in the batch,
>>> # D is the number of messages received per node for this node batch
>>> # where N is the number of nodes in the batch, D is the number
>>> # of messages received per node for this node batch.
>>> return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing
>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
......@@ -298,20 +306,20 @@ class NodeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph and set a feature 'h'
>>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received and the original feature
>>> # for each node
>>> # Define a UDF that computes the sum of the messages received and
>>> # the original feature for each node.
>>> def node_udf(nodes):
>>> # nodes.data['h'] is a tensor of shape (N, 1),
>>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
>>> # where N is the number of nodes in the batch,
>>> # D is the number of messages received per node for this node batch
>>> # where N is the number of nodes in the batch, D is the number
>>> # of messages received per node for this node batch.
>>> return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing
>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
......@@ -336,20 +344,21 @@ class NodeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph and set a feature 'h'
>>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received and the original ID
>>> # for each node
>>> # Define a UDF that computes the sum of the messages received and
>>> # the original ID for each node.
>>> def node_udf(nodes):
>>> # nodes.nodes() is a tensor of shape (N),
>>> # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
>>> # where N is the number of nodes in the batch,
>>> # D is the number of messages received per node for this node batch
>>> return {'h': nodes.nodes().unsqueeze(-1).float() + nodes.mailbox['m'].sum(1)}
>>> # where N is the number of nodes in the batch, D is the number
>>> # of messages received per node for this node batch.
>>> return {'h': nodes.nodes().unsqueeze(-1).float()
>>> + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing
>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
......@@ -372,16 +381,17 @@ class NodeBatch(object):
>>> import dgl
>>> import torch
>>> # Instantiate a graph
>>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received for each node
>>> # and increments the result by 1
>>> # Define a UDF that computes the sum of the messages received for
>>> # each node and increments the result by 1.
>>> def node_udf(nodes):
>>> return {'h': torch.ones(nodes.batch_size(), 1) + nodes.mailbox['m'].sum(1)}
>>> return {'h': torch.ones(nodes.batch_size(), 1)
>>> + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing
>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment