Unverified Commit 93ac29ce authored by Zihao Ye, committed by GitHub

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent 8874e830
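For orientation, a minimal sketch of the unified API after this refactor, assuming the post-refactor behavior shown in the diff below (dgl.batch returns a plain DGLGraph that carries batch_size and batch_num_nodes, and readouts always return a tensor with a leading batch dimension):

import dgl
import torch as th

g1 = dgl.DGLGraph()
g1.add_nodes(2)                      # 2 nodes
g1.add_edge(0, 1)                    # edge 0 -> 1
g1.ndata['hv'] = th.tensor([[0.], [1.]])

g2 = dgl.DGLGraph()
g2.add_nodes(3)
g2.add_edges([0, 2], [1, 1])
g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]])

bg = dgl.batch([g1, g2])             # now a DGLGraph, no separate BatchedDGLGraph class
print(bg.batch_size)                 # 2
print(bg.batch_num_nodes)            # [2, 3]
print(dgl.sum_nodes(bg, 'hv'))       # tensor([[1.], [9.]]), shape (B, *)
print(dgl.sum_nodes(g1, 'hv'))       # tensor([[1.]]), a single graph acts as a batch of size 1
g1, g2 = dgl.unbatch(bg)             # split back into a list of DGLGraph objects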
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .gnn import GCNLayer, GATLayer from .gnn import GCNLayer, GATLayer
from ...batched_graph import BatchedDGLGraph, max_nodes from ...readout import max_nodes
from ...nn.pytorch import WeightAndSum from ...nn.pytorch import WeightAndSum
from ...contrib.deprecation import deprecated from ...contrib.deprecation import deprecated
...@@ -74,13 +74,13 @@ class BaseGNNClassifier(nn.Module): ...@@ -74,13 +74,13 @@ class BaseGNNClassifier(nn.Module):
self.soft_classifier = MLPBinaryClassifier( self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout) self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, bg, feats): def forward(self, g, feats):
"""Multi-task prediction for a batch of molecules """Multi-task prediction for a batch of molecules
Parameters Parameters
---------- ----------
- bg : BatchedDGLGraph
- B Batched DGLGraphs for processing multiple molecules in parallel
+ g : DGLGraph
+ DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M0) feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules Initial features for all atoms in the batch of molecules
...@@ -91,18 +91,15 @@ class BaseGNNClassifier(nn.Module): ...@@ -91,18 +91,15 @@ class BaseGNNClassifier(nn.Module):
""" """
# Update atom features with GNNs # Update atom features with GNNs
for gnn in self.gnn_layers: for gnn in self.gnn_layers:
feats = gnn(bg, feats) feats = gnn(g, feats)
# Compute molecule features from atom features # Compute molecule features from atom features
h_g_sum = self.weighted_sum_readout(bg, feats) h_g_sum = self.weighted_sum_readout(g, feats)
with bg.local_scope(): with g.local_scope():
bg.ndata['h'] = feats g.ndata['h'] = feats
h_g_max = max_nodes(bg, 'h') h_g_max = max_nodes(g, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1) h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction # Multi-task prediction
......
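The hunk above drops the BatchedDGLGraph check and always produces a (B, *) readout. A minimal sketch of that readout pattern, assuming the post-refactor dgl.max_nodes and graph.local_scope() shown in this diff (weighted_sum_readout is an illustrative callable, not this repo's class):

import dgl
import torch

def concat_readout(g, feats, weighted_sum_readout):
    # Weighted-sum readout over atoms, shape (B, D)
    h_g_sum = weighted_sum_readout(g, feats)
    # Max readout over atoms; local_scope() keeps the scratch field 'h' out of g
    with g.local_scope():
        g.ndata['h'] = feats
        h_g_max = dgl.max_nodes(g, 'h')          # (B, D)
    # Concatenate both graph-level summaries: (B, 2 * D)
    return torch.cat([h_g_sum, h_g_max], dim=1)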
...@@ -42,13 +42,13 @@ class GCNLayer(nn.Module): ...@@ -42,13 +42,13 @@ class GCNLayer(nn.Module):
if batchnorm: if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats) self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, bg, feats): def forward(self, g, feats):
"""Update atom representations """Update atom representations
Parameters Parameters
---------- ----------
- bg : BatchedDGLGraph
- Batched DGLGraphs for processing multiple molecules in parallel
+ g : DGLGraph
+ DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1) feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph * N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization * M1 is the input atom feature size, must match in_feats in initialization
...@@ -58,7 +58,7 @@ class GCNLayer(nn.Module): ...@@ -58,7 +58,7 @@ class GCNLayer(nn.Module):
new_feats : FloatTensor of shape (N, M2) new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization * M2 is the output atom feature size, must match out_feats in initialization
""" """
new_feats = self.graph_conv(bg, feats) new_feats = self.graph_conv(g, feats)
if self.residual: if self.residual:
res_feats = self.activation(self.res_connection(feats)) res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats new_feats = new_feats + res_feats
......
...@@ -7,7 +7,7 @@ import torch.nn.functional as F ...@@ -7,7 +7,7 @@ import torch.nn.functional as F
import rdkit.Chem as Chem import rdkit.Chem as Chem
from ....batched_graph import batch, unbatch from ....graph import batch, unbatch
from ....contrib.deprecation import deprecated from ....contrib.deprecation import deprecated
from ....data.utils import get_download_dir from ....data.utils import get_download_dir
......
...@@ -63,7 +63,7 @@ class ChebConv(nn.Block): ...@@ -63,7 +63,7 @@ class ChebConv(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
...@@ -3,8 +3,7 @@ ...@@ -3,8 +3,7 @@
from mxnet import gluon, nd from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
from ... import BatchedDGLGraph from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling', __all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
...@@ -24,7 +23,7 @@ class SumPooling(nn.Block): ...@@ -24,7 +23,7 @@ class SumPooling(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -33,9 +32,8 @@ class SumPooling(nn.Block): ...@@ -33,9 +32,8 @@ class SumPooling(nn.Block):
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -61,7 +59,7 @@ class AvgPooling(nn.Block): ...@@ -61,7 +59,7 @@ class AvgPooling(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -70,9 +68,8 @@ class AvgPooling(nn.Block): ...@@ -70,9 +68,8 @@ class AvgPooling(nn.Block):
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -98,7 +95,7 @@ class MaxPooling(nn.Block): ...@@ -98,7 +95,7 @@ class MaxPooling(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -107,9 +104,8 @@ class MaxPooling(nn.Block): ...@@ -107,9 +104,8 @@ class MaxPooling(nn.Block):
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -139,7 +135,7 @@ class SortPooling(nn.Block): ...@@ -139,7 +135,7 @@ class SortPooling(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -148,9 +144,8 @@ class SortPooling(nn.Block): ...@@ -148,9 +144,8 @@ class SortPooling(nn.Block):
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
- The output feature with shape :math:`(k * D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, k * D)`.
+ The output feature with shape :math:`(B, k * D)`, where
+ :math:`B` refers to the batch size.
""" """
# Sort the feature of each node in ascending order. # Sort the feature of each node in ascending order.
with graph.local_scope(): with graph.local_scope():
...@@ -159,10 +154,7 @@ class SortPooling(nn.Block): ...@@ -159,10 +154,7 @@ class SortPooling(nn.Block):
# Sort nodes according to their last features. # Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k)[0].reshape( ret = topk_nodes(graph, 'h', self.k)[0].reshape(
-1, self.k * feat.shape[-1]) -1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph): return ret
return ret
else:
return ret.squeeze(axis=0)
def __repr__(self): def __repr__(self):
return 'SortPooling(k={})'.format(self.k) return 'SortPooling(k={})'.format(self.k)
...@@ -195,7 +187,7 @@ class GlobalAttentionPooling(nn.Block): ...@@ -195,7 +187,7 @@ class GlobalAttentionPooling(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -204,9 +196,8 @@ class GlobalAttentionPooling(nn.Block): ...@@ -204,9 +196,8 @@ class GlobalAttentionPooling(nn.Block):
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
- The output feature with shape :math:`(D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, D)`.
+ The output feature with shape :math:`(B, D)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
gate = self.gate_nn(feat) gate = self.gate_nn(feat)
...@@ -263,7 +254,7 @@ class Set2Set(nn.Block): ...@@ -263,7 +254,7 @@ class Set2Set(nn.Block):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : mxnet.NDArray feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -272,14 +263,11 @@ class Set2Set(nn.Block): ...@@ -272,14 +263,11 @@ class Set2Set(nn.Block):
Returns Returns
------- -------
mxnet.NDArray mxnet.NDArray
- The output feature with shape :math:`(D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, D)`.
+ The output feature with shape :math:`(B, D)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
batch_size = 1 batch_size = graph.batch_size
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context), h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context)) nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
...@@ -288,23 +276,14 @@ class Set2Set(nn.Block): ...@@ -288,23 +276,14 @@ class Set2Set(nn.Block):
for _ in range(self.n_iters): for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h) q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim)) q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True) e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e') alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r') readout = sum_nodes(graph, 'r')
if readout.ndim == 1: # graph is not a BatchedDGLGraph
readout = readout.expand_dims(0)
q_star = nd.concat(q, readout, dim=-1) q_star = nd.concat(q, readout, dim=-1)
if isinstance(graph, BatchedDGLGraph): return q_star
return q_star
else:
return q_star.squeeze(axis=0)
def __repr__(self): def __repr__(self):
summary = 'Set2Set(' summary = 'Set2Set('
......
...@@ -203,7 +203,7 @@ class AtomicConv(nn.Module): ...@@ -203,7 +203,7 @@ class AtomicConv(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
Topology based on which message passing is performed. Topology based on which message passing is performed.
feat : Float32 tensor of shape (V, 1) feat : Float32 tensor of shape (V, 1)
Initial node features, which are atomic numbers in the paper. Initial node features, which are atomic numbers in the paper.
......
...@@ -68,7 +68,7 @@ class ChebConv(nn.Module): ...@@ -68,7 +68,7 @@ class ChebConv(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
...@@ -4,9 +4,8 @@ import torch as th ...@@ -4,9 +4,8 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from ... import BatchedDGLGraph
from ...backend import pytorch as F from ...backend import pytorch as F
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\ from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes softmax_nodes, topk_nodes
...@@ -29,7 +28,7 @@ class SumPooling(nn.Module): ...@@ -29,7 +28,7 @@ class SumPooling(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -38,9 +37,8 @@ class SumPooling(nn.Module): ...@@ -38,9 +37,8 @@ class SumPooling(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -62,7 +60,7 @@ class AvgPooling(nn.Module): ...@@ -62,7 +60,7 @@ class AvgPooling(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -71,9 +69,8 @@ class AvgPooling(nn.Module): ...@@ -71,9 +69,8 @@ class AvgPooling(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -95,7 +92,7 @@ class MaxPooling(nn.Module): ...@@ -95,7 +92,7 @@ class MaxPooling(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -104,9 +101,8 @@ class MaxPooling(nn.Module): ...@@ -104,9 +101,8 @@ class MaxPooling(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -132,7 +128,7 @@ class SortPooling(nn.Module): ...@@ -132,7 +128,7 @@ class SortPooling(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -141,9 +137,8 @@ class SortPooling(nn.Module): ...@@ -141,9 +137,8 @@ class SortPooling(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(k * D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, k * D)`.
+ The output feature with shape :math:`(B, k * D)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
# Sort the feature of each node in ascending order. # Sort the feature of each node in ascending order.
...@@ -152,10 +147,7 @@ class SortPooling(nn.Module): ...@@ -152,10 +147,7 @@ class SortPooling(nn.Module):
# Sort nodes according to their last features. # Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].view( ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].view(
-1, self.k * feat.shape[-1]) -1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph): return ret
return ret
else:
return ret.squeeze(0)
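For intuition, a rough sketch of what SortPooling computes on a padded batch, in plain PyTorch (illustrative only; the library code above works on packed node features via topk_nodes):

import torch

def sort_pool_padded(feat_padded, k):
    # feat_padded: (B, L, D) node features padded with -inf along L; returns (B, k * D)
    B, L, D = feat_padded.shape
    feat_sorted, _ = feat_padded.sort(dim=-1)                       # sort each node's own features
    order = feat_sorted[:, :, -1].argsort(dim=1, descending=True)   # rank nodes by their last feature
    topk = order[:, :k]                                              # top-k node positions per graph
    gathered = torch.gather(feat_sorted, 1, topk.unsqueeze(-1).expand(B, k, D))
    return gathered.reshape(B, k * D)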
class GlobalAttentionPooling(nn.Module): class GlobalAttentionPooling(nn.Module):
...@@ -193,9 +185,8 @@ class GlobalAttentionPooling(nn.Module): ...@@ -193,9 +185,8 @@ class GlobalAttentionPooling(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, D)`.
+ The output feature with shape :math:`(B, D)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
gate = self.gate_nn(feat) gate = self.gate_nn(feat)
...@@ -257,7 +248,7 @@ class Set2Set(nn.Module): ...@@ -257,7 +248,7 @@ class Set2Set(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -266,14 +257,11 @@ class Set2Set(nn.Module): ...@@ -266,14 +257,11 @@ class Set2Set(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, D)`.
+ The output feature with shape :math:`(B, D)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
batch_size = 1 batch_size = graph.batch_size
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)), h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
feat.new_zeros((self.n_layers, batch_size, self.input_dim))) feat.new_zeros((self.n_layers, batch_size, self.input_dim)))
...@@ -283,23 +271,14 @@ class Set2Set(nn.Module): ...@@ -283,23 +271,14 @@ class Set2Set(nn.Module):
for _ in range(self.n_iters): for _ in range(self.n_iters):
q, h = self.lstm(q_star.unsqueeze(0), h) q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.input_dim) q = q.view(batch_size, self.input_dim)
e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True) e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
graph.ndata['e'] = e graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e') alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r') readout = sum_nodes(graph, 'r')
if readout.dim() == 1: # graph is not a BatchedDGLGraph
readout = readout.unsqueeze(0)
q_star = th.cat([q, readout], dim=-1) q_star = th.cat([q, readout], dim=-1)
if isinstance(graph, BatchedDGLGraph): return q_star
return q_star
else:
return q_star.squeeze(0)
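A minimal sketch of one Set2Set attention step with the unified readout helpers, assuming dgl.broadcast_nodes, dgl.softmax_nodes and dgl.sum_nodes behave as documented in this diff (illustrative, not the module above):

import dgl
import torch

def set2set_step(graph, feat, q):
    # feat: (N, D) packed node features, q: (B, D) per-graph query
    e = (feat * dgl.broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)  # (N, 1) scores
    with graph.local_scope():
        graph.ndata['e'] = e
        alpha = dgl.softmax_nodes(graph, 'e')      # softmax over the nodes of each graph
        graph.ndata['r'] = feat * alpha
        readout = dgl.sum_nodes(graph, 'r')        # (B, D) attention-weighted readout
    return torch.cat([q, readout], dim=-1)         # next q_star, shape (B, 2 * D)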
def extra_repr(self): def extra_repr(self):
"""Set the extra representation of the module. """Set the extra representation of the module.
...@@ -574,7 +553,7 @@ class SetTransformerEncoder(nn.Module): ...@@ -574,7 +553,7 @@ class SetTransformerEncoder(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -585,14 +564,9 @@ class SetTransformerEncoder(nn.Module): ...@@ -585,14 +564,9 @@ class SetTransformerEncoder(nn.Module):
torch.Tensor torch.Tensor
The output feature with shape :math:`(N, D)`. The output feature with shape :math:`(N, D)`.
""" """
if isinstance(graph, BatchedDGLGraph): lengths = graph.batch_num_nodes
lengths = graph.batch_num_nodes
else:
lengths = [graph.number_of_nodes()]
for layer in self.layers: for layer in self.layers:
feat = layer(feat, lengths) feat = layer(feat, lengths)
return feat return feat
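With the unified graph, batch_num_nodes is always defined (a single graph reports a one-element list), so the per-graph lengths can drive packed-tensor operations directly. A small illustration in plain PyTorch, with made-up sizes:

import torch

feat = torch.randn(5, 16)              # packed node features for N = 2 + 3 nodes
lengths = [2, 3]                        # what graph.batch_num_nodes would report for two graphs
chunks = torch.split(feat, lengths)     # per-graph views: shapes (2, 16) and (3, 16)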
...@@ -640,7 +614,7 @@ class SetTransformerDecoder(nn.Module): ...@@ -640,7 +614,7 @@ class SetTransformerDecoder(nn.Module):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -649,25 +623,15 @@ class SetTransformerDecoder(nn.Module): ...@@ -649,25 +623,15 @@ class SetTransformerDecoder(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
- The output feature with shape :math:`(D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, D)`.
+ The output feature with shape :math:`(B, D)`, where
+ :math:`B` refers to the batch size.
""" """
if isinstance(graph, BatchedDGLGraph): len_pma = graph.batch_num_nodes
len_pma = graph.batch_num_nodes len_sab = [self.k] * graph.batch_size
len_sab = [self.k] * graph.batch_size
else:
len_pma = [graph.number_of_nodes()]
len_sab = [self.k]
feat = self.pma(feat, len_pma) feat = self.pma(feat, len_pma)
for layer in self.layers: for layer in self.layers:
feat = layer(feat, len_sab) feat = layer(feat, len_sab)
return feat.view(graph.batch_size, self.k * self.d_model)
if isinstance(graph, BatchedDGLGraph):
return feat.view(graph.batch_size, self.k * self.d_model)
else:
return feat.view(self.k * self.d_model)
class WeightAndSum(nn.Module): class WeightAndSum(nn.Module):
...@@ -686,13 +650,13 @@ class WeightAndSum(nn.Module): ...@@ -686,13 +650,13 @@ class WeightAndSum(nn.Module):
nn.Sigmoid() nn.Sigmoid()
) )
def forward(self, bg, feats): def forward(self, g, feats):
"""Compute molecule representations out of atom representations """Compute molecule representations out of atom representations
Parameters Parameters
---------- ----------
- bg : BatchedDGLGraph
- B Batched DGLGraphs for processing multiple molecules in parallel
+ g : DGLGraph
+ DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats) feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules Representations for all atoms in the molecules
* N is the total number of atoms in all molecules * N is the total number of atoms in all molecules
...@@ -702,9 +666,9 @@ class WeightAndSum(nn.Module): ...@@ -702,9 +666,9 @@ class WeightAndSum(nn.Module):
FloatTensor of shape (B, self.in_feats) FloatTensor of shape (B, self.in_feats)
Representations for B molecules Representations for B molecules
""" """
with bg.local_scope(): with g.local_scope():
bg.ndata['h'] = feats g.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h']) g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w') h_g_sum = sum_nodes(g, 'h', 'w')
return h_g_sum return h_g_sum
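A minimal sketch of the weighted-sum readout used by WeightAndSum, assuming the post-refactor dgl.sum_nodes with a weight field as documented below (layer sizes are illustrative):

import dgl
import torch.nn as nn

atom_weighting = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())   # per-atom gate in (0, 1)

def weight_and_sum(g, feats):
    # feats: (N, 16) atom features; returns (B, 16) molecule features
    with g.local_scope():
        g.ndata['h'] = feats
        g.ndata['w'] = atom_weighting(g.ndata['h'])   # (N, 1) weights
        return dgl.sum_nodes(g, 'h', 'w')             # weighted sum per graph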
...@@ -4,8 +4,7 @@ import tensorflow as tf ...@@ -4,8 +4,7 @@ import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
from ... import BatchedDGLGraph from ...readout import sum_nodes, mean_nodes, max_nodes, \
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, \
softmax_nodes, topk_nodes softmax_nodes, topk_nodes
...@@ -29,7 +28,7 @@ class SumPooling(layers.Layer): ...@@ -29,7 +28,7 @@ class SumPooling(layers.Layer):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor feat : tf.Tensor
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -38,9 +37,8 @@ class SumPooling(layers.Layer): ...@@ -38,9 +37,8 @@ class SumPooling(layers.Layer):
Returns Returns
------- -------
tf.Tensor tf.Tensor
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -63,7 +61,7 @@ class AvgPooling(layers.Layer): ...@@ -63,7 +61,7 @@ class AvgPooling(layers.Layer):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor feat : tf.Tensor
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -72,9 +70,8 @@ class AvgPooling(layers.Layer): ...@@ -72,9 +70,8 @@ class AvgPooling(layers.Layer):
Returns Returns
------- -------
tf.Tensor tf.Tensor
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -97,7 +94,7 @@ class MaxPooling(layers.Layer): ...@@ -97,7 +94,7 @@ class MaxPooling(layers.Layer):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor feat : tf.Tensor
The input feature with shape :math:`(N, *)` where The input feature with shape :math:`(N, *)` where
...@@ -106,9 +103,8 @@ class MaxPooling(layers.Layer): ...@@ -106,9 +103,8 @@ class MaxPooling(layers.Layer):
Returns Returns
------- -------
tf.Tensor tf.Tensor
- The output feature with shape :math:`(*)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, *)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata['h'] = feat
...@@ -135,7 +131,7 @@ class SortPooling(layers.Layer): ...@@ -135,7 +131,7 @@ class SortPooling(layers.Layer):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor feat : tf.Tensor
The input feature with shape :math:`(N, D)` where The input feature with shape :math:`(N, D)` where
...@@ -144,9 +140,8 @@ class SortPooling(layers.Layer): ...@@ -144,9 +140,8 @@ class SortPooling(layers.Layer):
Returns Returns
------- -------
tf.Tensor tf.Tensor
- The output feature with shape :math:`(k * D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, k * D)`.
+ The output feature with shape :math:`(B, k * D)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
# Sort the feature of each node in ascending order. # Sort the feature of each node in ascending order.
...@@ -155,10 +150,7 @@ class SortPooling(layers.Layer): ...@@ -155,10 +150,7 @@ class SortPooling(layers.Layer):
# Sort nodes according to their last features. # Sort nodes according to their last features.
ret = tf.reshape(topk_nodes(graph, 'h', self.k, idx=-1)[0], ( ret = tf.reshape(topk_nodes(graph, 'h', self.k, idx=-1)[0], (
-1, self.k * feat.shape[-1])) -1, self.k * feat.shape[-1]))
if isinstance(graph, BatchedDGLGraph): return ret
return ret
else:
return tf.squeeze(ret, 0)
class GlobalAttentionPooling(layers.Layer): class GlobalAttentionPooling(layers.Layer):
...@@ -197,9 +189,8 @@ class GlobalAttentionPooling(layers.Layer): ...@@ -197,9 +189,8 @@ class GlobalAttentionPooling(layers.Layer):
Returns Returns
------- -------
tf.Tensor tf.Tensor
- The output feature with shape :math:`(D)` (if
- input graph is a BatchedDGLGraph, the result shape
- would be :math:`(B, D)`.
+ The output feature with shape :math:`(B, *)`, where
+ :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
gate = self.gate_nn(feat) gate = self.gate_nn(feat)
...@@ -234,13 +225,13 @@ class WeightAndSum(layers.Layer): ...@@ -234,13 +225,13 @@ class WeightAndSum(layers.Layer):
layers.Activation(tf.nn.sigmoid) layers.Activation(tf.nn.sigmoid)
) )
def call(self, bg, feats): def call(self, g, feats):
"""Compute molecule representations out of atom representations """Compute molecule representations out of atom representations
Parameters Parameters
---------- ----------
- bg : BatchedDGLGraph
- B Batched DGLGraphs for processing multiple molecules in parallel
+ g : DGLGraph
+ DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats) feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules Representations for all atoms in the molecules
* N is the total number of atoms in all molecules * N is the total number of atoms in all molecules
...@@ -250,9 +241,9 @@ class WeightAndSum(layers.Layer): ...@@ -250,9 +241,9 @@ class WeightAndSum(layers.Layer):
FloatTensor of shape (B, self.in_feats) FloatTensor of shape (B, self.in_feats)
Representations for B molecules Representations for B molecules
""" """
with bg.local_scope(): with g.local_scope():
bg.ndata['h'] = feats g.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h']) g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w') h_g_sum = sum_nodes(g, 'h', 'w')
return h_g_sum return h_g_sum
"""Classes and functions for batching multiple graphs together.""" """Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import from __future__ import absolute_import
from collections.abc import Iterable
import numpy as np import numpy as np
- from .base import ALL, is_all, DGLError
- from .frame import FrameRef, Frame
- from .graph import DGLGraph
- from . import graph_index as gi
+ from .base import DGLError
from . import backend as F
- from . import utils
- __all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
- 'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
+ __all__ = ['sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
'max_nodes', 'max_edges', 'softmax_nodes', 'softmax_edges',
'broadcast_nodes', 'broadcast_edges', 'topk_nodes', 'topk_edges']
class BatchedDGLGraph(DGLGraph):
"""Class for batched DGL graphs.
A :class:`BatchedDGLGraph` basically merges a list of small graphs into a giant
graph so that one can perform message passing and readout over a batch of graphs
simultaneously.
The nodes and edges are re-indexed with a new id in the batched graph with the
rule below:
====== ========== ======================== === ==========================
item Graph 1 Graph 2 ... Graph k
====== ========== ======================== === ==========================
raw id 0, ..., N1 0, ..., N2 ... ..., Nk
new id 0, ..., N1 N1 + 1, ..., N1 + N2 + 1 ... ..., N1 + ... + Nk + k - 1
====== ========== ======================== === ==========================
The batched graph is read-only, i.e. one cannot further add nodes and edges.
A ``RuntimeError`` will be raised if one attempts.
Modifying the features in a :class:`BatchedDGLGraph` has no effect on the original
graphs. See the examples below for how to work around this.
Parameters
----------
graph_list : iterable
A collection of :class:`~dgl.DGLGraph` objects to be batched.
node_attrs : None, str or iterable, optional
The node attributes to be batched. If ``None``, the :class:`BatchedDGLGraph` object
will not have any node attributes. By default, all node attributes will be batched.
An error will be raised if graphs having nodes have different attributes. If ``str``
or ``iterable``, this should specify exactly what node attributes to be batched.
edge_attrs : None, str or iterable, optional
Same as for the case of :attr:`node_attrs`
Examples
--------
Create two :class:`~dgl.DGLGraph` objects.
**Instantiation:**
>>> import dgl
>>> import torch as th
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(2) # Add 2 nodes
>>> g1.add_edge(0, 1) # Add edge 0 -> 1
>>> g1.ndata['hv'] = th.tensor([[0.], [1.]]) # Initialize node features
>>> g1.edata['he'] = th.tensor([[0.]]) # Initialize edge features
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3) # Add 3 nodes
>>> g2.add_edges([0, 2], [1, 1]) # Add edges 0 -> 1, 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]]) # Initialize edge features
Merge two :class:`~dgl.DGLGraph` objects into one :class:`BatchedDGLGraph` object.
When merging a list of graphs, we can choose to include only a subset of the attributes.
>>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata
{}
Below one can see that the nodes are re-indexed. The edges are re-indexed in
the same way.
>>> bg.nodes()
tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv']
tensor([[0.],
[1.],
[2.],
[3.],
[4.]])
**Property:**
We can still get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size
2
>>> bg.batch_num_nodes
[2, 3]
>>> bg.batch_num_edges
[1, 2]
**Readout:**
Another common demand for graph neural networks is graph readout, which is a
function that takes in the node attributes and/or edge attributes for a graph
and outputs a vector summarizing the information in the graph.
:class:`BatchedDGLGraph` also supports performing readout for a batch of graphs at once.
Below we take the built-in readout function :func:`sum_nodes` as an example, which
sums over a particular kind of node attribute for each graph.
>>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
tensor([[1.], # 0 + 1
[9.]]) # 2 + 3 + 4
**Message passing:**
For message passing and related operations, :class:`BatchedDGLGraph` acts exactly
the same as :class:`~dgl.DGLGraph`.
**Update Attributes:**
Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he']
tensor([[1.],
[2.]])
Instead, we can decompose the batched graph back into a list of graphs and use them
to replace the original graphs.
>>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraph objects
>>> g2.edata['he']
tensor([[0., 0.],
[0., 0.]])
"""
def __init__(self, graph_list, node_attrs, edge_attrs):
def _get_num_item_and_attr_types(g, mode):
if mode == 'node':
num_items = g.number_of_nodes()
attr_types = set(g.node_attr_schemes().keys())
elif mode == 'edge':
num_items = g.number_of_edges()
attr_types = set(g.edge_attr_schemes().keys())
return num_items, attr_types
def _init_attrs(attrs, mode):
if attrs is None:
return []
elif is_all(attrs):
attrs = set()
# Check if at least a graph has mode items and associated features.
for i, g in enumerate(graph_list):
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_num_items > 0 and len(g_attrs) > 0:
attrs = g_attrs
ref_g_index = i
break
# Check if all the graphs with mode items have the same associated features.
if len(attrs) > 0:
for i, g in enumerate(graph_list):
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_attrs != attrs and g_num_items > 0:
raise ValueError('Expect graph {0} and {1} to have the same {2} '
'attributes when {2}_attrs=ALL, got {3} and {4}.'
.format(ref_g_index, i, mode, attrs, g_attrs))
return attrs
elif isinstance(attrs, str):
return [attrs]
elif isinstance(attrs, Iterable):
return attrs
else:
raise ValueError('Expected {} attrs to be of type None str or Iterable, '
'got type {}'.format(mode, type(attrs)))
node_attrs = _init_attrs(node_attrs, 'node')
edge_attrs = _init_attrs(edge_attrs, 'edge')
# create batched graph index
batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames
if len(node_attrs) == 0:
batched_node_frame = FrameRef(Frame(num_rows=batched_index.number_of_nodes()))
else:
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
if gr.number_of_nodes() > 0], dim=0)
for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols))
if len(edge_attrs) == 0:
batched_edge_frame = FrameRef(Frame(num_rows=batched_index.number_of_edges()))
else:
cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
if gr.number_of_edges() > 0], dim=0)
for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
super(BatchedDGLGraph, self).__init__(graph_data=batched_index,
node_frame=batched_node_frame,
edge_frame=batched_edge_frame)
# extra members
self._batch_size = 0
self._batch_num_nodes = []
self._batch_num_edges = []
for grh in graph_list:
if isinstance(grh, BatchedDGLGraph):
# handle the input is again a batched graph.
self._batch_size += grh._batch_size
self._batch_num_nodes += grh._batch_num_nodes
self._batch_num_edges += grh._batch_num_edges
else:
self._batch_size += 1
self._batch_num_nodes.append(grh.number_of_nodes())
self._batch_num_edges.append(grh.number_of_edges())
@property
def batch_size(self):
"""Number of graphs in this batch.
Returns
-------
int
Number of graphs in this batch."""
return self._batch_size
@property
def batch_num_nodes(self):
"""Number of nodes of each graph in this batch.
Returns
-------
list
Number of nodes of each graph in this batch."""
return self._batch_num_nodes
@property
def batch_num_edges(self):
"""Number of edges of each graph in this batch.
Returns
-------
list
Number of edges of each graph in this batch."""
return self._batch_num_edges
# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
# new APIs
def __getitem__(self, idx):
"""Slice the batch and return the batch of graphs specified by the idx."""
# TODO
raise NotImplementedError
def __setitem__(self, idx, val):
"""Set the value of the slice. The graph size cannot be changed."""
# TODO
raise NotImplementedError
def split(graph_batch, num_or_size_splits): # pylint: disable=unused-argument
"""Split the batch."""
# TODO(minjie): could follow torch.split syntax
raise NotImplementedError
def unbatch(graph):
"""Return the list of graphs in this batch.
Parameters
----------
graph : BatchedDGLGraph
The batched graph.
Returns
-------
list
A list of :class:`~dgl.DGLGraph` objects whose attributes are obtained
by partitioning the attributes of the :attr:`graph`. The length of the
list is the same as the batch size of :attr:`graph`.
Notes
-----
Unbatching will break each field tensor of the batched graph into smaller
partitions.
For simpler tasks such as node/edge state aggregation, try to use
readout functions.
See Also
--------
batch
"""
assert isinstance(graph, BatchedDGLGraph)
bsize = graph.batch_size
bnn = graph.batch_num_nodes
bne = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bnn))
# split the frames
node_frames = [FrameRef(Frame(num_rows=n)) for n in bnn]
edge_frames = [FrameRef(Frame(num_rows=n)) for n in bne]
for attr, col in graph._node_frame.items():
col_splits = F.split(col, bnn, dim=0)
for i in range(bsize):
node_frames[i][attr] = col_splits[i]
for attr, col in graph._edge_frame.items():
col_splits = F.split(col, bne, dim=0)
for i in range(bsize):
edge_frames[i][attr] = col_splits[i]
return [DGLGraph(graph_data=pttns[i],
node_frame=node_frames[i],
edge_frame=edge_frames[i]) for i in range(bsize)]
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
"""Batch a collection of :class:`~dgl.DGLGraph` and return a
:class:`BatchedDGLGraph` object that is independent of the :attr:`graph_list`.
Parameters
----------
graph_list : iterable
A collection of :class:`~dgl.DGLGraph` to be batched.
node_attrs : None, str or iterable
The node attributes to be batched. If ``None``, the :class:`BatchedDGLGraph`
object will not have any node attributes. By default, all node attributes will
be batched. If ``str`` or iterable, this should specify exactly what node
attributes to be batched.
edge_attrs : None, str or iterable, optional
Same as for the case of :attr:`node_attrs`
Returns
-------
BatchedDGLGraph
One single batched graph
See Also
--------
BatchedDGLGraph
unbatch
"""
return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
READOUT_ON_ATTRS = { READOUT_ON_ATTRS = {
'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'), 'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
'edges': ('edata', 'batch_num_edges', 'number_of_edges'), 'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
...@@ -387,15 +43,12 @@ def _sum_on(graph, typestr, feat, weight): ...@@ -387,15 +43,12 @@ def _sum_on(graph, typestr, feat, weight):
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1)) weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
feat = weight * feat feat = weight * feat
if isinstance(graph, BatchedDGLGraph): n_graphs = graph.batch_size
n_graphs = graph.batch_size batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_num_objs = getattr(graph, batch_num_objs_attr) seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs)) seg_id = F.copy_to(seg_id, F.context(feat))
seg_id = F.copy_to(seg_id, F.context(feat)) y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0) return y
return y
else:
return F.sum(feat, 0)
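The new _sum_on path reduces to a segment sum keyed by each node's graph id. A minimal sketch of the same idea in plain PyTorch (illustrative helper, not the backend call):

import torch

def segment_sum(feat, batch_num_nodes):
    # feat: (N, D) packed node features; returns (B, D) per-graph sums
    n_graphs = len(batch_num_nodes)
    seg_id = torch.repeat_interleave(torch.arange(n_graphs),
                                     torch.tensor(batch_num_nodes))   # graph id of every node
    out = feat.new_zeros(n_graphs, feat.shape[1])
    return out.index_add_(0, seg_id, feat)                             # scatter-add rows into graphs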
def sum_nodes(graph, feat, weight=None): def sum_nodes(graph, feat, weight=None):
"""Sums all the values of node field :attr:`feat` in :attr:`graph`, optionally """Sums all the values of node field :attr:`feat` in :attr:`graph`, optionally
...@@ -420,10 +73,10 @@ def sum_nodes(graph, feat, weight=None): ...@@ -420,10 +73,10 @@ def sum_nodes(graph, feat, weight=None):
Notes Notes
----- -----
- If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
- returned instead, i.e. having an extra first dimension.
- Each row of the stacked tensor contains the readout result of the
- corresponding example in the batch. If an example has no nodes,
+ Return a stacked tensor with an extra first dimension whose size equals
+ batch size of the input graph.
+ The i-th row of the stacked tensor contains the readout result of the
+ i-th graph in the batched graph. If a graph has no nodes,
a zero tensor with the same shape is returned at the corresponding row.
Examples Examples
...@@ -456,7 +109,7 @@ def sum_nodes(graph, feat, weight=None): ...@@ -456,7 +109,7 @@ def sum_nodes(graph, feat, weight=None):
for a single graph. for a single graph.
>>> dgl.sum_nodes(g1, 'h', 'w') >>> dgl.sum_nodes(g1, 'h', 'w')
tensor([15.]) # 1 * 3 + 2 * 6 tensor([[15.]]) # 1 * 3 + 2 * 6
See Also See Also
-------- --------
...@@ -489,10 +142,10 @@ def sum_edges(graph, feat, weight=None): ...@@ -489,10 +142,10 @@ def sum_edges(graph, feat, weight=None):
Notes Notes
----- -----
- If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
- returned instead, i.e. having an extra first dimension.
- Each row of the stacked tensor contains the readout result of the
- corresponding example in the batch. If an example has no edges,
+ Return a stacked tensor with an extra first dimension whose size equals
+ batch size of the input graph.
+ The i-th row of the stacked tensor contains the readout result of the
+ i-th graph in the batched graph. If a graph has no edges,
a zero tensor with the same shape is returned at the corresponding row.
Examples Examples
...@@ -527,7 +180,7 @@ def sum_edges(graph, feat, weight=None): ...@@ -527,7 +180,7 @@ def sum_edges(graph, feat, weight=None):
for a single graph. for a single graph.
>>> dgl.sum_edges(g1, 'h', 'w') >>> dgl.sum_edges(g1, 'h', 'w')
tensor([15.]) # 1 * 3 + 2 * 6 tensor([[15.]]) # 1 * 3 + 2 * 6
See Also See Also
-------- --------
...@@ -566,24 +219,17 @@ def _mean_on(graph, typestr, feat, weight): ...@@ -566,24 +219,17 @@ def _mean_on(graph, typestr, feat, weight):
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1)) weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
feat = weight * feat feat = weight * feat
if isinstance(graph, BatchedDGLGraph): n_graphs = graph.batch_size
n_graphs = graph.batch_size batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_num_objs = getattr(graph, batch_num_objs_attr) seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs)) seg_id = F.copy_to(seg_id, F.context(feat))
seg_id = F.copy_to(seg_id, F.context(feat)) if weight is not None:
if weight is not None: w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0) y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0) y = y / w
y = y / w
else:
y = F.unsorted_1d_segment_mean(feat, seg_id, n_graphs, 0)
return y
else: else:
if weight is None: y = F.unsorted_1d_segment_mean(feat, seg_id, n_graphs, 0)
return F.mean(feat, 0) return y
else:
y = F.sum(feat, 0) / F.sum(weight, 0)
return y
def mean_nodes(graph, feat, weight=None): def mean_nodes(graph, feat, weight=None):
"""Averages all the values of node field :attr:`feat` in :attr:`graph`, """Averages all the values of node field :attr:`feat` in :attr:`graph`,
...@@ -591,7 +237,7 @@ def mean_nodes(graph, feat, weight=None): ...@@ -591,7 +237,7 @@ def mean_nodes(graph, feat, weight=None):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -608,10 +254,10 @@ def mean_nodes(graph, feat, weight=None): ...@@ -608,10 +254,10 @@ def mean_nodes(graph, feat, weight=None):
Notes Notes
----- -----
- If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
- returned instead, i.e. having an extra first dimension.
- Each row of the stacked tensor contains the readout result of
- corresponding example in the batch. If an example has no nodes,
+ Return a stacked tensor with an extra first dimension whose size equals
+ batch size of the input graph.
+ The i-th row of the stacked tensor contains the readout result of
+ the i-th graph in the batch. If a graph has no nodes,
a zero tensor with the same shape is returned at the corresponding row.
Examples Examples
...@@ -644,7 +290,7 @@ def mean_nodes(graph, feat, weight=None): ...@@ -644,7 +290,7 @@ def mean_nodes(graph, feat, weight=None):
for a single graph. for a single graph.
>>> dgl.mean_nodes(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2)) >>> dgl.mean_nodes(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2))
tensor([1.6667]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6)) tensor([[1.6667]]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))
See Also See Also
-------- --------
...@@ -677,10 +323,10 @@ def mean_edges(graph, feat, weight=None): ...@@ -677,10 +323,10 @@ def mean_edges(graph, feat, weight=None):
Notes Notes
----- -----
- If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
- returned instead, i.e. having an extra first dimension.
- Each row of the stacked tensor contains the readout result of
- corresponding example in the batch. If an example has no edges,
+ Return a stacked tensor with an extra first dimension whose size equals
+ batch size of the input graph.
+ The i-th row of the stacked tensor contains the readout result of
+ the i-th graph in the batched graph. If a graph has no edges,
a zero tensor with the same shape is returned at the corresponding row.
Examples Examples
...@@ -715,7 +361,7 @@ def mean_edges(graph, feat, weight=None): ...@@ -715,7 +361,7 @@ def mean_edges(graph, feat, weight=None):
for a single graph. for a single graph.
>>> dgl.mean_edges(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2)) >>> dgl.mean_edges(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2))
tensor([1.6667]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6)) tensor([[1.6667]]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))
See Also See Also
-------- --------
...@@ -750,12 +396,10 @@ def _max_on(graph, typestr, feat): ...@@ -750,12 +396,10 @@ def _max_on(graph, typestr, feat):
# TODO: the current solution pads the different graph sizes to the same, # TODO: the current solution pads the different graph sizes to the same,
# a more efficient way is to use segment max, we need to implement it in # a more efficient way is to use segment max, we need to implement it in
# the future. # the future.
if isinstance(graph, BatchedDGLGraph): batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_num_objs = getattr(graph, batch_num_objs_attr) feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf')) return F.max(feat, 1)
return F.max(feat, 1)
else:
return F.max(feat, 0)
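The padding trick used by _max_on, sketched in plain PyTorch (illustrative; the library relies on F.pad_packed_tensor from its backend):

import torch
from torch.nn.utils.rnn import pad_sequence

def segment_max(feat, batch_num_nodes):
    # feat: (N, D) packed node features; returns (B, D) per-graph maxima
    chunks = torch.split(feat, batch_num_nodes)               # per-graph (n_i, D) chunks
    padded = pad_sequence(chunks, batch_first=True,
                          padding_value=-float('inf'))        # (B, L_max, D); -inf never wins a max
    return padded.max(dim=1).values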
def _softmax_on(graph, typestr, feat): def _softmax_on(graph, typestr, feat):
"""Internal function of applying batch-wise graph-level softmax """Internal function of applying batch-wise graph-level softmax
...@@ -782,13 +426,10 @@ def _softmax_on(graph, typestr, feat): ...@@ -782,13 +426,10 @@ def _softmax_on(graph, typestr, feat):
# TODO: the current solution pads the different graph sizes to the same, # TODO: the current solution pads the different graph sizes to the same,
# a more efficient way is to use segment sum/max, we need to implement # a more efficient way is to use segment sum/max, we need to implement
# it in the future. # it in the future.
if isinstance(graph, BatchedDGLGraph): batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_num_objs = getattr(graph, batch_num_objs_attr) feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf')) feat = F.softmax(feat, 1)
feat = F.softmax(feat, 1) return F.pack_padded_tensor(feat, batch_num_objs)
return F.pack_padded_tensor(feat, batch_num_objs)
else:
return F.softmax(feat, 0)
def _broadcast_on(graph, typestr, feat_data): def _broadcast_on(graph, typestr, feat_data):
"""Internal function of broadcasting features to all nodes/edges. """Internal function of broadcasting features to all nodes/edges.
...@@ -808,21 +449,15 @@ def _broadcast_on(graph, typestr, feat_data): ...@@ -808,21 +449,15 @@ def _broadcast_on(graph, typestr, feat_data):
tensor tensor
The node/edge features tensor with shape :math:`(N, *)`. The node/edge features tensor with shape :math:`(N, *)`.
""" """
_, batch_num_objs_attr, num_objs_attr = READOUT_ON_ATTRS[typestr] _, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
if isinstance(graph, BatchedDGLGraph): batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_num_objs = getattr(graph, batch_num_objs_attr) index = []
index = [] for i, num_obj in enumerate(batch_num_objs):
for i, num_obj in enumerate(batch_num_objs): index.extend([i] * num_obj)
index.extend([i] * num_obj) ctx = F.context(feat_data)
ctx = F.context(feat_data) index = F.copy_to(F.tensor(index), ctx)
index = F.copy_to(F.tensor(index), ctx) return F.gather_row(feat_data, index)
return F.gather_row(feat_data, index)
else:
num_objs = getattr(graph, num_objs_attr)()
if F.ndim(feat_data) == 1:
feat_data = F.unsqueeze(feat_data, 0)
return F.cat([feat_data] * num_objs, 0)
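The new _broadcast_on simply gathers each graph's row once per node. A minimal sketch in plain PyTorch (illustrative helper):

import torch

def broadcast_to_nodes(graph_feat, batch_num_nodes):
    # graph_feat: (B, D) per-graph features; returns (N, D) node-aligned copies
    index = torch.repeat_interleave(torch.arange(len(batch_num_nodes)),
                                    torch.tensor(batch_num_nodes))    # graph id of every node
    return graph_feat[index]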
def _topk_on(graph, typestr, feat, k, descending=True, idx=None): def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
"""Internal function to take graph-wise top-k node/edge features of """Internal function to take graph-wise top-k node/edge features of
...@@ -832,7 +467,7 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None): ...@@ -832,7 +467,7 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
If idx is set to None, the function would return top-k value of all If idx is set to None, the function would return top-k value of all
indices, which is equivalent to calling `th.topk(graph.ndata[feat], dim=0)` indices, which is equivalent to calling `th.topk(graph.ndata[feat], dim=0)`
for each example of the input graph. for each single graph of the input batched-graph.
Parameters Parameters
--------- ---------
...@@ -854,14 +489,15 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None): ...@@ -854,14 +489,15 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
Returns Returns
------- -------
tuple of tensors: tuple of tensors:
- The first tensor returns top-k features of the given graph with
- shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
+ The first tensor returns top-k features of each single graph of
+ the input graph:
a tensor with shape :math:`(B, K, D)` would be returned, where
- :math:`B` is the batch size.
+ :math:`B` is the batch size of the input graph.
- The second tensor returns the top-k indices of the given graph
- with shape :math:`(K)`, if the input graph is a BatchedDGLGraph,
- a tensor with shape :math:`(B, K)` would be returned, where
- :math:`B` is the batch size.
+ The second tensor returns the top-k indices of each single graph
+ of the input graph:
+ a tensor with shape :math:`(B, K)` (or :math:`(B, K, D)` if idx
+ is set to None) would be returned, where
+ :math:`B` is the batch size of the input graph.
Notes Notes
----- -----
...@@ -870,7 +506,7 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None): ...@@ -870,7 +506,7 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
with all zero; in the second returned tensor, the behavior of :math:`n+1` with all zero; in the second returned tensor, the behavior of :math:`n+1`
to :math:`k`th elements is not defined. to :math:`k`th elements is not defined.
""" """
data_attr, batch_num_objs_attr, num_objs_attr = READOUT_ON_ATTRS[typestr] data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr) data = getattr(graph, data_attr)
if F.ndim(data[feat]) > 2: if F.ndim(data[feat]) > 2:
raise DGLError('The {} feature `{}` should have dimension less than or' raise DGLError('The {} feature `{}` should have dimension less than or'
...@@ -878,12 +514,8 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None): ...@@ -878,12 +514,8 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
feat = data[feat] feat = data[feat]
hidden_size = F.shape(feat)[-1] hidden_size = F.shape(feat)[-1]
if isinstance(graph, BatchedDGLGraph): batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_num_objs = getattr(graph, batch_num_objs_attr) batch_size = len(batch_num_objs)
batch_size = len(batch_num_objs)
else:
batch_num_objs = [getattr(graph, num_objs_attr)()]
batch_size = 1
length = max(max(batch_num_objs), k) length = max(max(batch_num_objs), k)
fill_val = -float('inf') if descending else float('inf') fill_val = -float('inf') if descending else float('inf')
...@@ -912,12 +544,8 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None): ...@@ -912,12 +544,8 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
shift = F.copy_to(shift, F.context(feat)) shift = F.copy_to(shift, F.context(feat))
topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
if isinstance(graph, BatchedDGLGraph): return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\
return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\ topk_indices
topk_indices
else:
return F.reshape(F.gather_row(feat_, topk_indices_), (k, -1)),\
topk_indices
def max_nodes(graph, feat): def max_nodes(graph, feat):
...@@ -926,7 +554,7 @@ def max_nodes(graph, feat): ...@@ -926,7 +554,7 @@ def max_nodes(graph, feat):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -963,14 +591,14 @@ def max_nodes(graph, feat): ...@@ -963,14 +591,14 @@ def max_nodes(graph, feat):
Max over node attribute :attr:`h` in a single graph. Max over node attribute :attr:`h` in a single graph.
>>> dgl.max_nodes(g1, 'h') >>> dgl.max_nodes(g1, 'h')
tensor([2.]) tensor([[2.]])
Notes Notes
----- -----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is Return a stacked tensor with an extra first dimension whose size equals
returned instead, i.e. having an extra first dimension. the batch size of the input graph.
Each row of the stacked tensor contains the readout result of The i-th row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no nodes, the i-th graph in the batched graph. If a graph has no nodes,
a tensor filled with -inf of the same shape is returned at the a tensor filled with -inf of the same shape is returned at the
corresponding row. corresponding row.
""" """
...@@ -982,7 +610,7 @@ def max_edges(graph, feat): ...@@ -982,7 +610,7 @@ def max_edges(graph, feat):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -1021,14 +649,14 @@ def max_edges(graph, feat): ...@@ -1021,14 +649,14 @@ def max_edges(graph, feat):
Max over edge attribute :attr:`h` in a single graph. Max over edge attribute :attr:`h` in a single graph.
>>> dgl.max_edges(g1, 'h') >>> dgl.max_edges(g1, 'h')
tensor([2.]) tensor([[2.]])
Notes Notes
----- -----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is Return a stacked tensor with an extra first dimension whose size equals
returned instead, i.e. having an extra first dimension. the batch size of the input graph.
Each row of the stacked tensor contains the readout result of The i-th row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no edges, the i-th graph in the batched graph. If a graph has no edges,
a tensor filled with -inf of the same shape is returned at the a tensor filled with -inf of the same shape is returned at the
corresponding row. corresponding row.
""" """
...@@ -1040,7 +668,7 @@ def softmax_nodes(graph, feat): ...@@ -1040,7 +668,7 @@ def softmax_nodes(graph, feat):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -1085,8 +713,8 @@ def softmax_nodes(graph, feat): ...@@ -1085,8 +713,8 @@ def softmax_nodes(graph, feat):
Notes Notes
----- -----
If graph is a :class:`BatchedDGLGraph` object, the softmax is applied at If the input graph has batch size greater than one, the softmax is applied to
each example in the batch. each graph in the batched graph.
""" """
return _softmax_on(graph, 'nodes', feat) return _softmax_on(graph, 'nodes', feat)
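A small sketch of the per-graph normalization noted above (a hedged example, PyTorch backend assumed): the softmax is computed within each constituent graph rather than over the whole batched graph.

import dgl
import torch as th

g1 = dgl.DGLGraph()
g1.add_nodes(2)
g1.ndata['h'] = th.zeros(2, 1)

g2 = dgl.DGLGraph()
g2.add_nodes(4)
g2.ndata['h'] = th.zeros(4, 1)

bg = dgl.batch([g1, g2])
dgl.softmax_nodes(bg, 'h')
# tensor([[0.5000], [0.5000],                          # graph 1: 2 nodes, each 1/2
#         [0.2500], [0.2500], [0.2500], [0.2500]])     # graph 2: 4 nodes, each 1/4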
...@@ -1097,7 +725,7 @@ def softmax_edges(graph, feat): ...@@ -1097,7 +725,7 @@ def softmax_edges(graph, feat):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -1144,7 +772,7 @@ def softmax_edges(graph, feat): ...@@ -1144,7 +772,7 @@ def softmax_edges(graph, feat):
Notes Notes
----- -----
If graph is a :class:`BatchedDGLGraph` object, the softmax is applied at each If the input graph has batch size greater than one, the softmax is applied to each
example in the batch. graph in the batched graph.
""" """
return _softmax_on(graph, 'edges', feat) return _softmax_on(graph, 'edges', feat)
...@@ -1155,7 +783,7 @@ def broadcast_nodes(graph, feat_data): ...@@ -1155,7 +783,7 @@ def broadcast_nodes(graph, feat_data):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat_data : tensor feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single graph, and The feature to broadcast. Tensor shape is :math:`(*)` for single graph, and
...@@ -1205,8 +833,7 @@ def broadcast_nodes(graph, feat_data): ...@@ -1205,8 +833,7 @@ def broadcast_nodes(graph, feat_data):
Notes Notes
----- -----
If graph is a :class:`BatchedDGLGraph` object, feat[i] is broadcast to the nodes feat[i] is broadcast to the nodes in the i-th graph in the batched graph.
in i-th example in the batch.
""" """
return _broadcast_on(graph, 'nodes', feat_data) return _broadcast_on(graph, 'nodes', feat_data)
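A minimal sketch of the broadcast semantics (PyTorch backend assumed): feat_data carries one row per graph in the batch, and row i is copied to every node of the i-th graph.

import dgl
import torch as th

g1 = dgl.DGLGraph()
g1.add_nodes(2)
g2 = dgl.DGLGraph()
g2.add_nodes(3)
bg = dgl.batch([g1, g2])

feat = th.tensor([[1.], [2.]])    # shape (B, 1) with B = 2
dgl.broadcast_nodes(bg, feat)
# tensor([[1.], [1.],             # graph 1's nodes get row 0
#         [2.], [2.], [2.]])      # graph 2's nodes get row 1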
...@@ -1216,7 +843,7 @@ def broadcast_edges(graph, feat_data): ...@@ -1216,7 +843,7 @@ def broadcast_edges(graph, feat_data):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat_data : tensor feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single The feature to broadcast. Tensor shape is :math:`(*)` for single
...@@ -1268,8 +895,7 @@ def broadcast_edges(graph, feat_data): ...@@ -1268,8 +895,7 @@ def broadcast_edges(graph, feat_data):
Notes Notes
----- -----
If graph is a :class:`BatchedDGLGraph` object, feat[i] is broadcast to feat[i] is broadcast to the edges in the i-th graph in the batched graph.
the edges in i-th example in the batch.
""" """
return _broadcast_on(graph, 'edges', feat_data) return _broadcast_on(graph, 'edges', feat_data)
...@@ -1285,7 +911,7 @@ def topk_nodes(graph, feat, k, descending=True, idx=None): ...@@ -1285,7 +911,7 @@ def topk_nodes(graph, feat, k, descending=True, idx=None):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -1300,15 +926,15 @@ def topk_nodes(graph, feat, k, descending=True, idx=None): ...@@ -1300,15 +926,15 @@ def topk_nodes(graph, feat, k, descending=True, idx=None):
Returns Returns
------- -------
tuple of tensors tuple of tensors
The first tensor returns top-k node features of the given graph The first tensor returns top-k node features of each single graph of
with shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph, the input graph:
a tensor with shape :math:`(B, K, D)` would be returned, where a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size. :math:`B` is the batch size of the input graph.
The second tensor returns the top-k edge indices of the given The second tensor returns the top-k node indices of each single graph
graph with shape :math:`(K)`(:math:`(K, D)` if idx is set to None), of the input graph:
if the input graph is a BatchedDGLGraph, a tensor with shape a tensor with shape :math:`(B, K)` (:math:`(B, K, D)` if idx
:math:`(B, K)`(:math:`(B, K, D)` if` idx is set to None) would be is set to None) would be returned, where
returned, where :math:`B` is the batch size. :math:`B` is the batch size of the input graph.
Examples Examples
-------- --------
...@@ -1372,9 +998,9 @@ def topk_nodes(graph, feat, k, descending=True, idx=None): ...@@ -1372,9 +998,9 @@ def topk_nodes(graph, feat, k, descending=True, idx=None):
Top-k over node attribute :attr:`h` in a single graph. Top-k over node attribute :attr:`h` in a single graph.
>>> dgl.topk_nodes(g1, 'h', 3) >>> dgl.topk_nodes(g1, 'h', 3)
(tensor([[0.5901, 0.8307, 0.9280, 0.8954, 0.7997], (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297], [0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]), tensor([[[1, 0, 1, 3, 1], [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2], [3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]]])) [2, 3, 2, 1, 3]]]))
...@@ -1400,7 +1026,7 @@ def topk_edges(graph, feat, k, descending=True, idx=None): ...@@ -1400,7 +1026,7 @@ def topk_edges(graph, feat, k, descending=True, idx=None):
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph. The graph.
feat : str feat : str
The feature field. The feature field.
...@@ -1415,15 +1041,15 @@ def topk_edges(graph, feat, k, descending=True, idx=None): ...@@ -1415,15 +1041,15 @@ def topk_edges(graph, feat, k, descending=True, idx=None):
Returns Returns
------- -------
tuple of tensors tuple of tensors
The first tensor returns top-k edge features of the given graph The first tensor returns top-k edge features of each single graph of
with shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph, the input graph:
a tensor with shape :math:`(B, K, D)` would be returned, where a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size. :math:`B` is the batch size of the input graph.
The second tensor returns the top-k edge indices of the given The second tensor returns the top-k edge indices of each single graph
graph with shape :math:`(K)`(:math:`(K, D)` if idx is set to None), of the input graph:
if the input graph is a BatchedDGLGraph, a tensor with shape a tensor with shape :math:`(B, K)`(:math:`(B, K, D)` if` idx
:math:`(B, K)`(:math:`(B, K, D)` if` idx is set to None) would be is set to None) would be returned, where
returned, where :math:`B` is the batch size. :math:`B` is the batch size of the input graph.
Examples Examples
-------- --------
...@@ -1489,9 +1115,9 @@ def topk_edges(graph, feat, k, descending=True, idx=None): ...@@ -1489,9 +1115,9 @@ def topk_edges(graph, feat, k, descending=True, idx=None):
Top-k over edge attribute :attr:`h` in a single graph. Top-k over edge attribute :attr:`h` in a single graph.
>>> dgl.topk_edges(g1, 'h', 3) >>> dgl.topk_edges(g1, 'h', 3)
(tensor([[0.5901, 0.8307, 0.9280, 0.8954, 0.7997], (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297], [0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]), tensor([[[1, 0, 1, 3, 1], [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2], [3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]]])) [2, 3, 2, 1, 3]]]))
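To make the (B, K, D) and (B, K) return shapes documented above for topk_nodes and topk_edges concrete, a hedged sketch (PyTorch backend assumed; only shapes are asserted):

import dgl
import torch as th

g = dgl.DGLGraph()
g.add_nodes(4)
g.ndata['h'] = th.randn(4, 5)

# idx=None: top-k is taken column-wise; values and indices are both (B, K, D)
vals, ids = dgl.topk_nodes(g, 'h', 3)
assert vals.shape == (1, 3, 5) and ids.shape == (1, 3, 5)

# idx=0: nodes are ranked by feature column 0; indices collapse to (B, K)
vals0, ids0 = dgl.topk_nodes(g, 'h', 3, idx=0)
assert vals0.shape == (1, 3, 5) and ids0.shape == (1, 3)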
......
"""Class for subgraph data structure."""
from __future__ import absolute_import
from .frame import Frame, FrameRef
from .graph import DGLGraph
from . import utils
from .base import DGLError
from .graph_index import map_to_subgraph_nid
class DGLSubGraph(DGLGraph):
"""The subgraph class.
There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to the parent graph, there are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only on structure; graph mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
sgi : SubgraphIndex
Internal subgraph data structure.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, sgi, shared=False):
super(DGLSubGraph, self).__init__(graph_data=sgi.graph,
readonly=True)
if shared:
raise DGLError('Shared mode is not yet supported.')
self._parent = parent
self._parent_nid = sgi.induced_nodes
self._parent_eid = sgi.induced_edges
self._subgraph_index = sgi
# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
@property
def parent_nid(self):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._parent_nid.tousertensor()
def _get_parent_eid(self):
# The parent eid might be lazily evaluated and thus may not
# be an index. Instead, it's a lambda function that returns
# an index.
if isinstance(self._parent_eid, utils.Index):
return self._parent_eid
else:
return self._parent_eid()
@property
def parent_eid(self):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._get_parent_eid().tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
Parameters
----------
inplace : bool
If true, use inplace write (no gradient but faster)
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
if self._parent._edge_frame.num_rows != 0:
self._parent._edge_frame.update_rows(
self._get_parent_eid(), self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
"""
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._get_parent_eid()]))
def map_to_subgraph_nid(self, parent_vids):
"""Map the node Ids in the parent graph to the node Ids in the subgraph.
Parameters
----------
parent_vids : list, tensor
The node ID array in the parent graph.
Returns
-------
tensor
The node ID array in the subgraph.
"""
v = map_to_subgraph_nid(self._subgraph_index, utils.toindex(parent_vids))
return v.tousertensor()
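For context on the class removed above, a hedged sketch of the old explicit-copy workflow it supported (node ids and features are illustrative only; PyTorch backend assumed):

import dgl
import torch as th

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 4])
g.ndata['h'] = th.arange(5, dtype=th.float32).unsqueeze(1)

sg = g.subgraph([1, 2, 3])   # induced subgraph on nodes 1, 2, 3
sg.copy_from_parent()        # pull node/edge features from the parent graph
sg.parent_nid                # maps subgraph node ids back to parent node ids
sg.ndata['h'] = sg.ndata['h'] + 1.
sg.copy_to_parent()          # write the updated rows back to the parent graph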
...@@ -7,12 +7,11 @@ from ._ffi.function import _init_api ...@@ -7,12 +7,11 @@ from ._ffi.function import _init_api
from .graph import DGLGraph from .graph import DGLGraph
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from . import ndarray as nd from . import ndarray as nd
from .subgraph import DGLSubGraph
from . import backend as F from . import backend as F
from .graph_index import from_coo from .graph_index import from_coo
from .graph_index import _get_halo_subgraph_inner_node from .graph_index import _get_halo_subgraph_inner_node
from .graph_index import _get_halo_subgraph_inner_edge from .graph_index import _get_halo_subgraph_inner_edge
from .batched_graph import BatchedDGLGraph, unbatch from .graph import unbatch
from .convert import graph, bipartite from .convert import graph, bipartite
from . import utils from . import utils
from .base import EID, NID from .base import EID, NID
...@@ -250,7 +249,6 @@ def reverse(g, share_ndata=False, share_edata=False): ...@@ -250,7 +249,6 @@ def reverse(g, share_ndata=False, share_edata=False):
Notes Notes
----- -----
* This function does not support :class:`~dgl.BatchedDGLGraph` objects.
* We do not dynamically update the topology of a graph once that of its reverse changes. * We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example, This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently, if the topology of both the original graph and its reverse get changed independently,
...@@ -307,12 +305,12 @@ def reverse(g, share_ndata=False, share_edata=False): ...@@ -307,12 +305,12 @@ def reverse(g, share_ndata=False, share_edata=False):
[2.], [2.],
[3.]]) [3.]])
""" """
assert not isinstance(g, BatchedDGLGraph), \
'reverse is not supported for a BatchedDGLGraph object'
g_reversed = DGLGraph(multigraph=g.is_multigraph) g_reversed = DGLGraph(multigraph=g.is_multigraph)
g_reversed.add_nodes(g.number_of_nodes()) g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.all_edges(order='eid') g_edges = g.all_edges(order='eid')
g_reversed.add_edges(g_edges[1], g_edges[0]) g_reversed.add_edges(g_edges[1], g_edges[0])
g_reversed._batch_num_nodes = g._batch_num_nodes
g_reversed._batch_num_edges = g._batch_num_edges
if share_ndata: if share_ndata:
g_reversed._node_frame = g._node_frame g_reversed._node_frame = g._node_frame
if share_edata: if share_edata:
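A short sketch of the behaviour added in this hunk (hedged; PyTorch backend assumed): reversing a batched graph now carries over the per-graph bookkeeping, so the result still behaves like a batch.

import dgl

g1 = dgl.DGLGraph()
g1.add_nodes(2)
g1.add_edge(0, 1)
g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edge(1, 0)

bg = dgl.batch([g1, g2])
rg = dgl.reverse(bg)
rg.batch_size           # 2: the batch structure is preserved
rg.edges()              # (tensor([1, 2]), tensor([0, 3])), edge directions flipped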
...@@ -391,17 +389,14 @@ def laplacian_lambda_max(g): ...@@ -391,17 +389,14 @@ def laplacian_lambda_max(g):
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
The input graph, it should be an undirected graph. The input graph, it should be an undirected graph.
Returns Returns
------- -------
list : list :
* If the input g is a DGLGraph, the returned value would be Return a list, where the i-th item indicates the largest eigenvalue
a list with one element, indicating the largest eigenvalue of g. of the i-th graph in g.
* If the input g is a BatchedDGLGraph, the returned value would
be a list, where the i-th item indicates the largest eigenvalue
of i-th graph in g.
Examples Examples
-------- --------
...@@ -413,11 +408,7 @@ def laplacian_lambda_max(g): ...@@ -413,11 +408,7 @@ def laplacian_lambda_max(g):
>>> dgl.laplacian_lambda_max(g) >>> dgl.laplacian_lambda_max(g)
[1.809016994374948] [1.809016994374948]
""" """
if isinstance(g, BatchedDGLGraph): g_arr = unbatch(g)
g_arr = unbatch(g)
else:
g_arr = [g]
rst = [] rst = []
for g_i in g_arr: for g_i in g_arr:
n = g_i.number_of_nodes() n = g_i.number_of_nodes()
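A follow-up sketch of the batched behaviour documented above: for a batched graph the returned list holds one eigenvalue per constituent graph (graph construction mirrors the docstring example; values are approximate).

import dgl

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3])

bg = dgl.batch([g, g])
dgl.laplacian_lambda_max(bg)   # a list of length 2, e.g. [1.809016994374948, 1.809016994374948]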
...@@ -573,7 +564,7 @@ def partition_graph_with_halo(g, node_part, num_hops): ...@@ -573,7 +564,7 @@ def partition_graph_with_halo(g, node_part, num_hops):
for i, subg in enumerate(subgs): for i, subg in enumerate(subgs):
inner_node = _get_halo_subgraph_inner_node(subg) inner_node = _get_halo_subgraph_inner_node(subg)
inner_edge = _get_halo_subgraph_inner_edge(subg) inner_edge = _get_halo_subgraph_inner_edge(subg)
subg = DGLSubGraph(g, subg) subg = g._create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack()) inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
subg.ndata['inner_node'] = inner_node subg.ndata['inner_node'] = inner_node
inner_edge = F.zerocopy_from_dlpack(inner_edge.to_dlpack()) inner_edge = F.zerocopy_from_dlpack(inner_edge.to_dlpack())
......
...@@ -124,8 +124,8 @@ def test_node_subgraph(): ...@@ -124,8 +124,8 @@ def test_node_subgraph():
subig = ig.node_subgraph(utils.toindex(randv)) subig = ig.node_subgraph(utils.toindex(randv))
check_basics(subg.graph, subig.graph) check_basics(subg.graph, subig.graph)
check_graph_equal(subg.graph, subig.graph) check_graph_equal(subg.graph, subig.graph)
assert F.asnumpy(map_to_subgraph_nid(subg, utils.toindex(randv1[0:10])).tousertensor() assert F.asnumpy(map_to_subgraph_nid(subg.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()
== map_to_subgraph_nid(subig, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10 == map_to_subgraph_nid(subig.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
# node_subgraphs # node_subgraphs
randvs = [] randvs = []
......
...@@ -195,7 +195,7 @@ def test_softmax_edges(): ...@@ -195,7 +195,7 @@ def test_softmax_edges():
def test_broadcast_nodes(): def test_broadcast_nodes():
# test#1: basic # test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10)) g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,)) feat0 = F.randn((1, 40))
ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0) ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth) assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
...@@ -204,23 +204,23 @@ def test_broadcast_nodes(): ...@@ -204,23 +204,23 @@ def test_broadcast_nodes():
g2 = dgl.DGLGraph() g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12)) g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3]) bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,)) feat1 = F.randn((1, 40))
feat2 = F.randn((40,)) feat2 = F.randn((1, 40))
feat3 = F.randn((40,)) feat3 = F.randn((1, 40))
ground_truth = F.stack( ground_truth = F.cat(
[feat0] * g0.number_of_nodes() +\ [feat0] * g0.number_of_nodes() +\
[feat1] * g1.number_of_nodes() +\ [feat1] * g1.number_of_nodes() +\
[feat2] * g2.number_of_nodes() +\ [feat2] * g2.number_of_nodes() +\
[feat3] * g3.number_of_nodes(), 0 [feat3] * g3.number_of_nodes(), 0
) )
assert F.allclose(dgl.broadcast_nodes( assert F.allclose(dgl.broadcast_nodes(
bg, F.stack([feat0, feat1, feat2, feat3], 0) bg, F.cat([feat0, feat1, feat2, feat3], 0)
), ground_truth) ), ground_truth)
def test_broadcast_edges(): def test_broadcast_edges():
# test#1: basic # test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10)) g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,)) feat0 = F.randn((1, 40))
ground_truth = F.stack([feat0] * g0.number_of_edges(), 0) ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth) assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
...@@ -229,17 +229,17 @@ def test_broadcast_edges(): ...@@ -229,17 +229,17 @@ def test_broadcast_edges():
g2 = dgl.DGLGraph() g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12)) g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3]) bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,)) feat1 = F.randn((1, 40))
feat2 = F.randn((40,)) feat2 = F.randn((1, 40))
feat3 = F.randn((40,)) feat3 = F.randn((1, 40))
ground_truth = F.stack( ground_truth = F.cat(
[feat0] * g0.number_of_edges() +\ [feat0] * g0.number_of_edges() +\
[feat1] * g1.number_of_edges() +\ [feat1] * g1.number_of_edges() +\
[feat2] * g2.number_of_edges() +\ [feat2] * g2.number_of_edges() +\
[feat3] * g3.number_of_edges(), 0 [feat3] * g3.number_of_edges(), 0
) )
assert F.allclose(dgl.broadcast_edges( assert F.allclose(dgl.broadcast_edges(
bg, F.stack([feat0, feat1, feat2, feat3], 0) bg, F.cat([feat0, feat1, feat2, feat3], 0)
), ground_truth) ), ground_truth)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -41,13 +41,13 @@ def test_basics(): ...@@ -41,13 +41,13 @@ def test_basics():
eid = {2, 3, 4, 5, 10, 11, 12, 13, 16} eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
assert set(F.zerocopy_to_numpy(sg.parent_eid)) == eid assert set(F.zerocopy_to_numpy(sg.parent_eid)) == eid
eid = F.tensor(sg.parent_eid) eid = F.tensor(sg.parent_eid)
# the subgraph is empty initially # the subgraph is empty initially except for NID/EID field
assert len(sg.ndata) == 0
assert len(sg.edata) == 0
# the data is copied after explict copy from
sg.copy_from_parent()
assert len(sg.ndata) == 1 assert len(sg.ndata) == 1
assert len(sg.edata) == 1 assert len(sg.edata) == 1
# the data is copied after explict copy from
sg.copy_from_parent()
assert len(sg.ndata) == 2
assert len(sg.edata) == 2
sh = sg.ndata['h'] sh = sg.ndata['h']
assert F.allclose(F.gather_row(h, F.tensor(nid)), sh) assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)
''' '''
......
...@@ -328,7 +328,7 @@ def test_set2set(): ...@@ -328,7 +328,7 @@ def test_set2set():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(g, h0) h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph # test#2: batched graph
bg = dgl.batch([g, g, g]) bg = dgl.batch([g, g, g])
...@@ -346,7 +346,7 @@ def test_glob_att_pool(): ...@@ -346,7 +346,7 @@ def test_glob_att_pool():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0) h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph # test#2: batched graph
bg = dgl.batch([g, g, g, g]) bg = dgl.batch([g, g, g, g])
...@@ -366,13 +366,13 @@ def test_simple_pool(): ...@@ -366,13 +366,13 @@ def test_simple_pool():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(g, h0) h1 = sum_pool(g, h0)
check_close(h1, F.sum(h0, 0)) check_close(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0) h1 = avg_pool(g, h0)
check_close(h1, F.mean(h0, 0)) check_close(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0) h1 = max_pool(g, h0)
check_close(h1, F.max(h0, 0)) check_close(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0) h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
# test#2: batched graph # test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5)) g_ = dgl.DGLGraph(nx.path_graph(5))
......
...@@ -124,7 +124,7 @@ def test_set2set(): ...@@ -124,7 +124,7 @@ def test_set2set():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(g, h0) h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
# test#2: batched graph # test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(11)) g1 = dgl.DGLGraph(nx.path_graph(11))
...@@ -145,7 +145,7 @@ def test_glob_att_pool(): ...@@ -145,7 +145,7 @@ def test_glob_att_pool():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0) h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
# test#2: batched graph # test#2: batched graph
bg = dgl.batch([g, g, g, g]) bg = dgl.batch([g, g, g, g])
...@@ -170,13 +170,13 @@ def test_simple_pool(): ...@@ -170,13 +170,13 @@ def test_simple_pool():
max_pool = max_pool.to(ctx) max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx) sort_pool = sort_pool.to(ctx)
h1 = sum_pool(g, h0) h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0)) assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0) h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0)) assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0) h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0)) assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0) h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2
# test#2: batched graph # test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5)) g_ = dgl.DGLGraph(nx.path_graph(5))
...@@ -228,7 +228,7 @@ def test_set_trans(): ...@@ -228,7 +228,7 @@ def test_set_trans():
h1 = st_enc_1(g, h0) h1 = st_enc_1(g, h0)
assert h1.shape == h0.shape assert h1.shape == h0.shape
h2 = st_dec(g, h1) h2 = st_dec(g, h1)
assert h2.shape[0] == 200 and h2.dim() == 1 assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2
# test#2: batched graph # test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5)) g1 = dgl.DGLGraph(nx.path_graph(5))
......
...@@ -93,13 +93,13 @@ def test_simple_pool(): ...@@ -93,13 +93,13 @@ def test_simple_pool():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(g, h0) h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0)) assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0) h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0)) assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0) h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0)) assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0) h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
# test#2: batched graph # test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5)) g_ = dgl.DGLGraph(nx.path_graph(5))
...@@ -246,7 +246,7 @@ def test_glob_att_pool(): ...@@ -246,7 +246,7 @@ def test_glob_att_pool():
# test#1: basic # test#1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0) h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1 assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph # test#2: batched graph
bg = dgl.batch([g, g, g, g]) bg = dgl.batch([g, g, g, g])
......
...@@ -91,12 +91,13 @@ def plot_tree(g): ...@@ -91,12 +91,13 @@ def plot_tree(g):
plot_tree(graph.to_networkx()) plot_tree(graph.to_networkx())
################################################################################# #################################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`, or # You can read more about the definition of :func:`~dgl.batch`, or
# skip ahead to the next step: # skip ahead to the next step:
# .. note:: # .. note::
# #
# **Definition**: A :class:`~dgl.batched_graph.BatchedDGLGraph` is a # **Definition**: :func:`~dgl.batch` unions a list of :math:`B`
# :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s. # :class:`~dgl.DGLGraph`\ s and returns a :class:`~dgl.DGLGraph` of batch
# size :math:`B`.
# #
# - The union includes all the nodes, # - The union includes all the nodes,
# edges, and their features. The order of nodes, edges, and features are # edges, and their features. The order of nodes, edges, and features are
...@@ -108,23 +109,16 @@ plot_tree(graph.to_networkx()) ...@@ -108,23 +109,16 @@ plot_tree(graph.to_networkx())
# :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph. # :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
# #
# - Therefore, performing feature transformation and message passing on # - Therefore, performing feature transformation and message passing on
# ``BatchedDGLGraph`` is equivalent to doing those # the batched graph is equivalent to doing those
# on all ``DGLGraph`` constituents in parallel. # on all ``DGLGraph`` constituents in parallel.
# #
# - Duplicate references to the same graph are # - Duplicate references to the same graph are
# treated as deep copies; the nodes, edges, and features are duplicated, # treated as deep copies; the nodes, edges, and features are duplicated,
# and mutation on one reference does not affect the other. # and mutation on one reference does not affect the other.
# - Currently, ``BatchedDGLGraph`` is immutable in # - The batched graph keeps track of the meta
# graph structure. You can't add
# nodes and edges to it. You need to support mutable batched graphs in
# (far) future.
# - The ``BatchedDGLGraph`` keeps track of the meta
# information of the constituents so it can be # information of the constituents so it can be
# :func:`~dgl.batched_graph.unbatch`\ ed to list of ``DGLGraph``\ s. # :func:`~dgl.unbatch`\ ed to a list of ``DGLGraph``\ s.
# #
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
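A self-contained sketch of the batching semantics summarised in the note above (PyTorch backend assumed; the graphs are illustrative only):

import dgl

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 1], [1, 2])
g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edge(0, 1)

bg = dgl.batch([g1, g2])
bg.number_of_nodes()    # 5: node j of graph i is shifted by the sizes of earlier graphs
bg.batch_size           # 2
bg.batch_num_nodes      # [3, 2]: per-graph meta information is kept

# Message passing and NN modules run on bg once, over all constituents in parallel,
# and the batch can be split back into the original graphs:
g1_back, g2_back = dgl.unbatch(bg)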
# Step 2: Tree-LSTM cell with message-passing APIs # Step 2: Tree-LSTM cell with message-passing APIs
# ------------------------------------------------ # ------------------------------------------------
# #
......
...@@ -798,9 +798,8 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid)) ...@@ -798,9 +798,8 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__ # <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__
# , and it is worth explaining one more time why this is so. # , and it is worth explaining one more time why this is so.
# #
# By batching many small graphs, DGL internally maintains a large *container* # By batching many small graphs, DGL parallelizes message passing on each individual
# graph (``BatchedDGLGraph``) over which ``update_all`` propels message-passing # graph in a batch.
# on all the edges and nodes.
# #
# With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single giant # With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single giant
# graph consisting of :math:`N` isolated small graphs. For example, if we # graph consisting of :math:`N` isolated small graphs. For example, if we
...@@ -833,7 +832,7 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid)) ...@@ -833,7 +832,7 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# internally groups nodes with the same in-degrees and calls reduce UDF once # internally groups nodes with the same in-degrees and calls reduce UDF once
# for each group. Thus, batching also reduces number of calls to these UDFs. # for each group. Thus, batching also reduces number of calls to these UDFs.
# #
# The modification of the node/edge features of a ``BatchedDGLGraph`` object # The modification of the node/edge features of the batched graph object
# does not take effect on the features of the original small graphs, so we # does not take effect on the features of the original small graphs, so we
# need to replace the old graph list with the new graph list # need to replace the old graph list with the new graph list
# ``g_list = dgl.unbatch(bg)``. # ``g_list = dgl.unbatch(bg)``.
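A short hedged sketch of this point (PyTorch backend assumed): feature updates made on the batched graph are only reflected in the per-graph objects after unbatching.

import dgl
import torch as th

g_list = []
for _ in range(2):
    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.ndata['x'] = th.zeros(2, 1)
    g_list.append(g)

bg = dgl.batch(g_list)
bg.ndata['x'] = bg.ndata['x'] + 1.   # update features on the batched graph

g_list[0].ndata['x']                 # still zeros: the originals are untouched
g_list = dgl.unbatch(bg)             # replace the list to pick up the new features
g_list[0].ndata['x']                 # now ones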
......