Unverified Commit 93ac29ce, authored by Zihao Ye and committed by GitHub

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent 8874e830
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .gnn import GCNLayer, GATLayer
-from ...batched_graph import BatchedDGLGraph, max_nodes
+from ...readout import max_nodes
 from ...nn.pytorch import WeightAndSum
 from ...contrib.deprecation import deprecated
@@ -74,13 +74,13 @@ class BaseGNNClassifier(nn.Module):
         self.soft_classifier = MLPBinaryClassifier(
             self.g_feats, classifier_hidden_feats, n_tasks, dropout)
-    def forward(self, bg, feats):
+    def forward(self, g, feats):
         """Multi-task prediction for a batch of molecules
         Parameters
         ----------
-        bg : BatchedDGLGraph
-            B Batched DGLGraphs for processing multiple molecules in parallel
+        g : DGLGraph
+            DGLGraph with batch size B for processing multiple molecules in parallel
         feats : FloatTensor of shape (N, M0)
             Initial features for all atoms in the batch of molecules
@@ -91,18 +91,15 @@ class BaseGNNClassifier(nn.Module):
         """
         # Update atom features with GNNs
         for gnn in self.gnn_layers:
-            feats = gnn(bg, feats)
+            feats = gnn(g, feats)
         # Compute molecule features from atom features
-        h_g_sum = self.weighted_sum_readout(bg, feats)
-        with bg.local_scope():
-            bg.ndata['h'] = feats
-            h_g_max = max_nodes(bg, 'h')
-        if not isinstance(bg, BatchedDGLGraph):
-            h_g_sum = h_g_sum.unsqueeze(0)
-            h_g_max = h_g_max.unsqueeze(0)
+        h_g_sum = self.weighted_sum_readout(g, feats)
+        with g.local_scope():
+            g.ndata['h'] = feats
+            h_g_max = max_nodes(g, 'h')
         h_g = torch.cat([h_g_sum, h_g_max], dim=1)
         # Multi-task prediction
......
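The removed isinstance branch is unnecessary because, after this refactor, readout functions such as max_nodes keep a leading batch dimension even for a single graph (a lone DGLGraph behaves as a batch of size 1). A rough sketch of the assumed behavior, with made-up node counts and feature names:

    import dgl
    import torch

    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(2)
    bg = dgl.batch([g1, g2])                 # a DGLGraph with batch_size == 2

    bg.ndata['h'] = torch.randn(bg.number_of_nodes(), 4)
    print(dgl.max_nodes(bg, 'h').shape)      # expected: torch.Size([2, 4])

    g1.ndata['h'] = torch.randn(3, 4)
    print(dgl.max_nodes(g1, 'h').shape)      # a lone graph acts as a batch of 1: torch.Size([1, 4])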
@@ -42,13 +42,13 @@ class GCNLayer(nn.Module):
         if batchnorm:
             self.bn_layer = nn.BatchNorm1d(out_feats)
-    def forward(self, bg, feats):
+    def forward(self, g, feats):
         """Update atom representations
         Parameters
         ----------
-        bg : BatchedDGLGraph
-            Batched DGLGraphs for processing multiple molecules in parallel
+        g : DGLGraph
+            DGLGraph with batch size B for processing multiple molecules in parallel
         feats : FloatTensor of shape (N, M1)
             * N is the total number of atoms in the batched graph
             * M1 is the input atom feature size, must match in_feats in initialization
@@ -58,7 +58,7 @@ class GCNLayer(nn.Module):
         new_feats : FloatTensor of shape (N, M2)
             * M2 is the output atom feature size, must match out_feats in initialization
         """
-        new_feats = self.graph_conv(bg, feats)
+        new_feats = self.graph_conv(g, feats)
         if self.residual:
             res_feats = self.activation(self.res_connection(feats))
             new_feats = new_feats + res_feats
......
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import rdkit.Chem as Chem
-from ....batched_graph import batch, unbatch
+from ....graph import batch, unbatch
 from ....contrib.deprecation import deprecated
 from ....data.utils import get_download_dir
......
@@ -63,7 +63,7 @@ class ChebConv(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
@@ -3,8 +3,7 @@
 from mxnet import gluon, nd
 from mxnet.gluon import nn
-from ... import BatchedDGLGraph
-from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
+from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
     softmax_nodes, topk_nodes
 __all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
@@ -24,7 +23,7 @@ class SumPooling(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
@@ -33,9 +32,8 @@ class SumPooling(nn.Block):
         Returns
         -------
         mxnet.NDArray
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -61,7 +59,7 @@ class AvgPooling(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
@@ -70,9 +68,8 @@ class AvgPooling(nn.Block):
         Returns
         -------
         mxnet.NDArray
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -98,7 +95,7 @@ class MaxPooling(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
@@ -107,9 +104,8 @@ class MaxPooling(nn.Block):
         Returns
         -------
         mxnet.NDArray
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -139,7 +135,7 @@ class SortPooling(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
@@ -148,9 +144,8 @@ class SortPooling(nn.Block):
         Returns
         -------
         mxnet.NDArray
-            The output feature with shape :math:`(k * D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, k * D)`.
+            The output feature with shape :math:`(B, k * D)`, where
+            :math:`B` refers to the batch size.
         """
         # Sort the feature of each node in ascending order.
         with graph.local_scope():
@@ -159,10 +154,7 @@ class SortPooling(nn.Block):
             # Sort nodes according to their last features.
             ret = topk_nodes(graph, 'h', self.k)[0].reshape(
                 -1, self.k * feat.shape[-1])
-            if isinstance(graph, BatchedDGLGraph):
-                return ret
-            else:
-                return ret.squeeze(axis=0)
+            return ret
     def __repr__(self):
         return 'SortPooling(k={})'.format(self.k)
@@ -195,7 +187,7 @@ class GlobalAttentionPooling(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
@@ -204,9 +196,8 @@ class GlobalAttentionPooling(nn.Block):
         Returns
         -------
         mxnet.NDArray
-            The output feature with shape :math:`(D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, D)`.
+            The output feature with shape :math:`(B, D)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
@@ -263,7 +254,7 @@ class Set2Set(nn.Block):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
@@ -272,14 +263,11 @@ class Set2Set(nn.Block):
         Returns
         -------
         mxnet.NDArray
-            The output feature with shape :math:`(D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, D)`.
+            The output feature with shape :math:`(B, D)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
-            batch_size = 1
-            if isinstance(graph, BatchedDGLGraph):
-                batch_size = graph.batch_size
+            batch_size = graph.batch_size
             h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
                  nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
@@ -288,23 +276,14 @@ class Set2Set(nn.Block):
             for _ in range(self.n_iters):
                 q, h = self.lstm(q_star.expand_dims(axis=0), h)
                 q = q.reshape((batch_size, self.input_dim))
                 e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
                 graph.ndata['e'] = e
                 alpha = softmax_nodes(graph, 'e')
                 graph.ndata['r'] = feat * alpha
                 readout = sum_nodes(graph, 'r')
-                if readout.ndim == 1: # graph is not a BatchedDGLGraph
-                    readout = readout.expand_dims(0)
                 q_star = nd.concat(q, readout, dim=-1)
-            if isinstance(graph, BatchedDGLGraph):
-                return q_star
-            else:
-                return q_star.squeeze(axis=0)
+            return q_star
     def __repr__(self):
         summary = 'Set2Set('
......
@@ -203,7 +203,7 @@ class AtomicConv(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             Topology based on which message passing is performed.
         feat : Float32 tensor of shape (V, 1)
             Initial node features, which are atomic numbers in the paper.
......
@@ -68,7 +68,7 @@ class ChebConv(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
@@ -4,9 +4,8 @@ import torch as th
 import torch.nn as nn
 import numpy as np
-from ... import BatchedDGLGraph
 from ...backend import pytorch as F
-from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
+from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
     softmax_nodes, topk_nodes
@@ -29,7 +28,7 @@ class SumPooling(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
@@ -38,9 +37,8 @@ class SumPooling(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -62,7 +60,7 @@ class AvgPooling(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
@@ -71,9 +69,8 @@ class AvgPooling(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -95,7 +92,7 @@ class MaxPooling(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
@@ -104,9 +101,8 @@ class MaxPooling(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -132,7 +128,7 @@ class SortPooling(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
@@ -141,9 +137,8 @@ class SortPooling(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(k * D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, k * D)`.
+            The output feature with shape :math:`(B, k * D)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             # Sort the feature of each node in ascending order.
@@ -152,10 +147,7 @@ class SortPooling(nn.Module):
             # Sort nodes according to their last features.
             ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].view(
                 -1, self.k * feat.shape[-1])
-            if isinstance(graph, BatchedDGLGraph):
-                return ret
-            else:
-                return ret.squeeze(0)
+            return ret
 class GlobalAttentionPooling(nn.Module):
@@ -193,9 +185,8 @@ class GlobalAttentionPooling(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, D)`.
+            The output feature with shape :math:`(B, D)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
@@ -257,7 +248,7 @@ class Set2Set(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
@@ -266,14 +257,11 @@ class Set2Set(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, D)`.
+            The output feature with shape :math:`(B, D)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
-            batch_size = 1
-            if isinstance(graph, BatchedDGLGraph):
-                batch_size = graph.batch_size
+            batch_size = graph.batch_size
             h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
                  feat.new_zeros((self.n_layers, batch_size, self.input_dim)))
@@ -283,23 +271,14 @@ class Set2Set(nn.Module):
             for _ in range(self.n_iters):
                 q, h = self.lstm(q_star.unsqueeze(0), h)
                 q = q.view(batch_size, self.input_dim)
                 e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
                 graph.ndata['e'] = e
                 alpha = softmax_nodes(graph, 'e')
                 graph.ndata['r'] = feat * alpha
                 readout = sum_nodes(graph, 'r')
-                if readout.dim() == 1: # graph is not a BatchedDGLGraph
-                    readout = readout.unsqueeze(0)
                 q_star = th.cat([q, readout], dim=-1)
-            if isinstance(graph, BatchedDGLGraph):
-                return q_star
-            else:
-                return q_star.squeeze(0)
+            return q_star
     def extra_repr(self):
         """Set the extra representation of the module.
@@ -574,7 +553,7 @@ class SetTransformerEncoder(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
@@ -585,14 +564,9 @@ class SetTransformerEncoder(nn.Module):
         torch.Tensor
             The output feature with shape :math:`(N, D)`.
         """
-        if isinstance(graph, BatchedDGLGraph):
-            lengths = graph.batch_num_nodes
-        else:
-            lengths = [graph.number_of_nodes()]
+        lengths = graph.batch_num_nodes
        for layer in self.layers:
            feat = layer(feat, lengths)
        return feat
@@ -640,7 +614,7 @@ class SetTransformerDecoder(nn.Module):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
@@ -649,25 +623,15 @@ class SetTransformerDecoder(nn.Module):
         Returns
         -------
         torch.Tensor
-            The output feature with shape :math:`(D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, D)`.
+            The output feature with shape :math:`(B, D)`, where
+            :math:`B` refers to the batch size.
         """
-        if isinstance(graph, BatchedDGLGraph):
-            len_pma = graph.batch_num_nodes
-            len_sab = [self.k] * graph.batch_size
-        else:
-            len_pma = [graph.number_of_nodes()]
-            len_sab = [self.k]
+        len_pma = graph.batch_num_nodes
+        len_sab = [self.k] * graph.batch_size
        feat = self.pma(feat, len_pma)
        for layer in self.layers:
            feat = layer(feat, len_sab)
-        if isinstance(graph, BatchedDGLGraph):
-            return feat.view(graph.batch_size, self.k * self.d_model)
-        else:
-            return feat.view(self.k * self.d_model)
+        return feat.view(graph.batch_size, self.k * self.d_model)
 class WeightAndSum(nn.Module):
@@ -686,13 +650,13 @@ class WeightAndSum(nn.Module):
             nn.Sigmoid()
         )
-    def forward(self, bg, feats):
+    def forward(self, g, feats):
         """Compute molecule representations out of atom representations
         Parameters
         ----------
-        bg : BatchedDGLGraph
-            B Batched DGLGraphs for processing multiple molecules in parallel
+        g : DGLGraph
+            DGLGraph with batch size B for processing multiple molecules in parallel
         feats : FloatTensor of shape (N, self.in_feats)
             Representations for all atoms in the molecules
             * N is the total number of atoms in all molecules
@@ -702,9 +666,9 @@ class WeightAndSum(nn.Module):
         FloatTensor of shape (B, self.in_feats)
             Representations for B molecules
         """
-        with bg.local_scope():
-            bg.ndata['h'] = feats
-            bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
-            h_g_sum = sum_nodes(bg, 'h', 'w')
+        with g.local_scope():
+            g.ndata['h'] = feats
+            g.ndata['w'] = self.atom_weighting(g.ndata['h'])
+            h_g_sum = sum_nodes(g, 'h', 'w')
         return h_g_sum
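With the BatchedDGLGraph branches gone, the pooling modules above return a tensor with a leading batch dimension for any input graph. A small usage sketch (the graph, feature sizes, and module choices here are illustrative, not part of the diff):

    import dgl
    import networkx as nx
    import torch
    from dgl.nn.pytorch.glob import SumPooling, AvgPooling

    g = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g, g])                       # batch of 3 graphs

    sum_pool, avg_pool = SumPooling(), AvgPooling()
    feat = torch.randn(bg.number_of_nodes(), 8)
    print(sum_pool(bg, feat).shape)                 # expected: torch.Size([3, 8])
    print(avg_pool(g, torch.randn(5, 8)).shape)     # single graph: torch.Size([1, 8])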
@@ -4,8 +4,7 @@ import tensorflow as tf
 from tensorflow.keras import layers
-from ... import BatchedDGLGraph
-from ...batched_graph import sum_nodes, mean_nodes, max_nodes, \
+from ...readout import sum_nodes, mean_nodes, max_nodes, \
     softmax_nodes, topk_nodes
@@ -29,7 +28,7 @@ class SumPooling(layers.Layer):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : tf.Tensor
             The input feature with shape :math:`(N, *)` where
@@ -38,9 +37,8 @@ class SumPooling(layers.Layer):
         Returns
         -------
         tf.Tensor
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -63,7 +61,7 @@ class AvgPooling(layers.Layer):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : tf.Tensor
             The input feature with shape :math:`(N, *)` where
@@ -72,9 +70,8 @@ class AvgPooling(layers.Layer):
         Returns
         -------
         tf.Tensor
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -97,7 +94,7 @@ class MaxPooling(layers.Layer):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : tf.Tensor
             The input feature with shape :math:`(N, *)` where
@@ -106,9 +103,8 @@ class MaxPooling(layers.Layer):
         Returns
         -------
         tf.Tensor
-            The output feature with shape :math:`(*)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, *)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             graph.ndata['h'] = feat
@@ -135,7 +131,7 @@ class SortPooling(layers.Layer):
         Parameters
         ----------
-        graph : DGLGraph or BatchedDGLGraph
+        graph : DGLGraph
             The graph.
         feat : tf.Tensor
             The input feature with shape :math:`(N, D)` where
@@ -144,9 +140,8 @@ class SortPooling(layers.Layer):
         Returns
         -------
         tf.Tensor
-            The output feature with shape :math:`(k * D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, k * D)`.
+            The output feature with shape :math:`(B, k * D)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             # Sort the feature of each node in ascending order.
@@ -155,10 +150,7 @@ class SortPooling(layers.Layer):
             # Sort nodes according to their last features.
             ret = tf.reshape(topk_nodes(graph, 'h', self.k, idx=-1)[0], (
                 -1, self.k * feat.shape[-1]))
-            if isinstance(graph, BatchedDGLGraph):
-                return ret
-            else:
-                return tf.squeeze(ret, 0)
+            return ret
 class GlobalAttentionPooling(layers.Layer):
@@ -197,9 +189,8 @@ class GlobalAttentionPooling(layers.Layer):
         Returns
         -------
         tf.Tensor
-            The output feature with shape :math:`(D)` (if
-            input graph is a BatchedDGLGraph, the result shape
-            would be :math:`(B, D)`.
+            The output feature with shape :math:`(B, *)`, where
+            :math:`B` refers to the batch size.
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
@@ -234,13 +225,13 @@ class WeightAndSum(layers.Layer):
             layers.Activation(tf.nn.sigmoid)
         )
-    def call(self, bg, feats):
+    def call(self, g, feats):
         """Compute molecule representations out of atom representations
         Parameters
         ----------
-        bg : BatchedDGLGraph
-            B Batched DGLGraphs for processing multiple molecules in parallel
+        g : DGLGraph
+            DGLGraph with batch size B for processing multiple molecules in parallel
         feats : FloatTensor of shape (N, self.in_feats)
             Representations for all atoms in the molecules
             * N is the total number of atoms in all molecules
@@ -250,9 +241,9 @@ class WeightAndSum(layers.Layer):
         FloatTensor of shape (B, self.in_feats)
             Representations for B molecules
         """
-        with bg.local_scope():
-            bg.ndata['h'] = feats
-            bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
-            h_g_sum = sum_nodes(bg, 'h', 'w')
+        with g.local_scope():
+            g.ndata['h'] = feats
+            g.ndata['w'] = self.atom_weighting(g.ndata['h'])
+            h_g_sum = sum_nodes(g, 'h', 'w')
         return h_g_sum
"""Class for subgraph data structure."""
from __future__ import absolute_import
from .frame import Frame, FrameRef
from .graph import DGLGraph
from . import utils
from .base import DGLError
from .graph_index import map_to_subgraph_nid
class DGLSubGraph(DGLGraph):
"""The subgraph class.
There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to the parent graph, there are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only on structure; graph mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
sgi : SubgraphIndex
Internal subgraph data structure.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, sgi, shared=False):
super(DGLSubGraph, self).__init__(graph_data=sgi.graph,
readonly=True)
if shared:
raise DGLError('Shared mode is not yet supported.')
self._parent = parent
self._parent_nid = sgi.induced_nodes
self._parent_eid = sgi.induced_edges
self._subgraph_index = sgi
# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
@property
def parent_nid(self):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._parent_nid.tousertensor()
def _get_parent_eid(self):
# The parent eid might be lazily evaluated and thus may not
# be an index. Instead, it's a lambda function that returns
# an index.
if isinstance(self._parent_eid, utils.Index):
return self._parent_eid
else:
return self._parent_eid()
@property
def parent_eid(self):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._get_parent_eid().tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
Parameters
----------
inplace : bool
If true, use inplace write (no gradient but faster)
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
if self._parent._edge_frame.num_rows != 0:
self._parent._edge_frame.update_rows(
self._get_parent_eid(), self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
"""
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._get_parent_eid()]))
def map_to_subgraph_nid(self, parent_vids):
"""Map the node Ids in the parent graph to the node Ids in the subgraph.
Parameters
----------
parent_vids : list, tensor
The node ID array in the parent graph.
Returns
-------
tensor
The node ID array in the subgraph.
"""
v = map_to_subgraph_nid(self._subgraph_index, utils.toindex(parent_vids))
return v.tousertensor()
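For context on the file removed above: the subgraph workflow its docstring describes survives the unification, only the dedicated DGLSubGraph class goes away. A hedged sketch of how the copy_from_parent / copy_to_parent round trip is typically used (graph shape, node IDs, and the feature name 'h' are invented for illustration):

    import dgl
    import torch

    parent = dgl.DGLGraph()
    parent.add_nodes(4)
    parent.add_edges([0, 1, 2], [1, 2, 3])
    parent.ndata['h'] = torch.zeros(4, 1)

    sg = parent.subgraph([1, 2, 3])   # formerly a DGLSubGraph, now a plain DGLGraph
    sg.copy_from_parent()             # pull node/edge features from the parent
    sg.ndata['h'] = sg.ndata['h'] + 1.0   # local update, invisible to the parent so far
    sg.copy_to_parent()               # write the updated rows back

    print(parent.ndata['h'])          # rows for nodes 1..3 updated, node 0 untouched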
@@ -7,12 +7,11 @@ from ._ffi.function import _init_api
 from .graph import DGLGraph
 from .heterograph import DGLHeteroGraph
 from . import ndarray as nd
-from .subgraph import DGLSubGraph
 from . import backend as F
 from .graph_index import from_coo
 from .graph_index import _get_halo_subgraph_inner_node
 from .graph_index import _get_halo_subgraph_inner_edge
-from .batched_graph import BatchedDGLGraph, unbatch
+from .graph import unbatch
 from .convert import graph, bipartite
 from . import utils
 from .base import EID, NID
@@ -250,7 +249,6 @@ def reverse(g, share_ndata=False, share_edata=False):
     Notes
     -----
-    * This function does not support :class:`~dgl.BatchedDGLGraph` objects.
     * We do not dynamically update the topology of a graph once that of its reverse changes.
       This can be particularly problematic when the node/edge attrs are shared. For example,
       if the topology of both the original graph and its reverse get changed independently,
@@ -307,12 +305,12 @@ def reverse(g, share_ndata=False, share_edata=False):
             [2.],
             [3.]])
     """
-    assert not isinstance(g, BatchedDGLGraph), \
-        'reverse is not supported for a BatchedDGLGraph object'
     g_reversed = DGLGraph(multigraph=g.is_multigraph)
     g_reversed.add_nodes(g.number_of_nodes())
     g_edges = g.all_edges(order='eid')
     g_reversed.add_edges(g_edges[1], g_edges[0])
+    g_reversed._batch_num_nodes = g._batch_num_nodes
+    g_reversed._batch_num_edges = g._batch_num_edges
     if share_ndata:
         g_reversed._node_frame = g._node_frame
     if share_edata:
@@ -391,17 +389,14 @@ def laplacian_lambda_max(g):
     Parameters
     ----------
-    g : DGLGraph or BatchedDGLGraph
+    g : DGLGraph
         The input graph, it should be an undirected graph.
     Returns
     -------
     list :
-        * If the input g is a DGLGraph, the returned value would be
-          a list with one element, indicating the largest eigenvalue of g.
-        * If the input g is a BatchedDGLGraph, the returned value would
-          be a list, where the i-th item indicates the largest eigenvalue
-          of i-th graph in g.
+        Return a list, where the i-th item indicates the largest eigenvalue
+        of i-th graph in g.
     Examples
     --------
@@ -413,11 +408,7 @@ def laplacian_lambda_max(g):
     >>> dgl.laplacian_lambda_max(g)
     [1.809016994374948]
     """
-    if isinstance(g, BatchedDGLGraph):
-        g_arr = unbatch(g)
-    else:
-        g_arr = [g]
+    g_arr = unbatch(g)
     rst = []
     for g_i in g_arr:
         n = g_i.number_of_nodes()
@@ -573,7 +564,7 @@ def partition_graph_with_halo(g, node_part, num_hops):
     for i, subg in enumerate(subgs):
         inner_node = _get_halo_subgraph_inner_node(subg)
         inner_edge = _get_halo_subgraph_inner_edge(subg)
-        subg = DGLSubGraph(g, subg)
+        subg = g._create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
         inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
         subg.ndata['inner_node'] = inner_node
         inner_edge = F.zerocopy_from_dlpack(inner_edge.to_dlpack())
......
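Since every DGLGraph now carries its own batch bookkeeping, laplacian_lambda_max simply unbatches whatever it receives and reports one eigenvalue per constituent graph; a single graph yields a one-element list. A hedged example (the graph choices are arbitrary):

    import dgl
    import networkx as nx

    g1 = dgl.DGLGraph(nx.cycle_graph(4))
    g2 = dgl.DGLGraph(nx.cycle_graph(6))
    bg = dgl.batch([g1, g2])

    print(dgl.laplacian_lambda_max(g1))   # a list with one entry, the largest eigenvalue of g1
    print(dgl.laplacian_lambda_max(bg))   # a list with one entry per graph in the batch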
@@ -124,8 +124,8 @@ def test_node_subgraph():
     subig = ig.node_subgraph(utils.toindex(randv))
     check_basics(subg.graph, subig.graph)
     check_graph_equal(subg.graph, subig.graph)
-    assert F.asnumpy(map_to_subgraph_nid(subg, utils.toindex(randv1[0:10])).tousertensor()
-                     == map_to_subgraph_nid(subig, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
+    assert F.asnumpy(map_to_subgraph_nid(subg.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()
+                     == map_to_subgraph_nid(subig.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
     # node_subgraphs
     randvs = []
......
@@ -195,7 +195,7 @@ def test_softmax_edges():
 def test_broadcast_nodes():
     # test#1: basic
     g0 = dgl.DGLGraph(nx.path_graph(10))
-    feat0 = F.randn((40,))
+    feat0 = F.randn((1, 40))
     ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
     assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
@@ -204,23 +204,23 @@ def test_broadcast_nodes():
     g2 = dgl.DGLGraph()
     g3 = dgl.DGLGraph(nx.path_graph(12))
     bg = dgl.batch([g0, g1, g2, g3])
-    feat1 = F.randn((40,))
-    feat2 = F.randn((40,))
-    feat3 = F.randn((40,))
-    ground_truth = F.stack(
+    feat1 = F.randn((1, 40))
+    feat2 = F.randn((1, 40))
+    feat3 = F.randn((1, 40))
+    ground_truth = F.cat(
         [feat0] * g0.number_of_nodes() +\
         [feat1] * g1.number_of_nodes() +\
         [feat2] * g2.number_of_nodes() +\
         [feat3] * g3.number_of_nodes(), 0
     )
     assert F.allclose(dgl.broadcast_nodes(
-        bg, F.stack([feat0, feat1, feat2, feat3], 0)
+        bg, F.cat([feat0, feat1, feat2, feat3], 0)
     ), ground_truth)
 def test_broadcast_edges():
     # test#1: basic
     g0 = dgl.DGLGraph(nx.path_graph(10))
-    feat0 = F.randn((40,))
+    feat0 = F.randn((1, 40))
     ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
     assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
@@ -229,17 +229,17 @@ def test_broadcast_edges():
     g2 = dgl.DGLGraph()
     g3 = dgl.DGLGraph(nx.path_graph(12))
     bg = dgl.batch([g0, g1, g2, g3])
-    feat1 = F.randn((40,))
-    feat2 = F.randn((40,))
-    feat3 = F.randn((40,))
-    ground_truth = F.stack(
+    feat1 = F.randn((1, 40))
+    feat2 = F.randn((1, 40))
+    feat3 = F.randn((1, 40))
+    ground_truth = F.cat(
         [feat0] * g0.number_of_edges() +\
         [feat1] * g1.number_of_edges() +\
         [feat2] * g2.number_of_edges() +\
         [feat3] * g3.number_of_edges(), 0
     )
     assert F.allclose(dgl.broadcast_edges(
-        bg, F.stack([feat0, feat1, feat2, feat3], 0)
+        bg, F.cat([feat0, feat1, feat2, feat3], 0)
     ), ground_truth)
 if __name__ == '__main__':
......
@@ -41,13 +41,13 @@ def test_basics():
     eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
     assert set(F.zerocopy_to_numpy(sg.parent_eid)) == eid
     eid = F.tensor(sg.parent_eid)
-    # the subgraph is empty initially
-    assert len(sg.ndata) == 0
-    assert len(sg.edata) == 0
-    # the data is copied after explict copy from
-    sg.copy_from_parent()
+    # the subgraph is empty initially except for NID/EID field
     assert len(sg.ndata) == 1
     assert len(sg.edata) == 1
+    # the data is copied after explict copy from
+    sg.copy_from_parent()
+    assert len(sg.ndata) == 2
+    assert len(sg.edata) == 2
     sh = sg.ndata['h']
     assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)
     '''
......
@@ -328,7 +328,7 @@ def test_set2set():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = s2s(g, h0)
-    assert h1.shape[0] == 10 and h1.ndim == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
     # test#2: batched graph
     bg = dgl.batch([g, g, g])
@@ -346,7 +346,7 @@ def test_glob_att_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(g, h0)
-    assert h1.shape[0] == 10 and h1.ndim == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
@@ -366,13 +366,13 @@ def test_simple_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = sum_pool(g, h0)
-    check_close(h1, F.sum(h0, 0))
+    check_close(F.squeeze(h1, 0), F.sum(h0, 0))
     h1 = avg_pool(g, h0)
-    check_close(h1, F.mean(h0, 0))
+    check_close(F.squeeze(h1, 0), F.mean(h0, 0))
     h1 = max_pool(g, h0)
-    check_close(h1, F.max(h0, 0))
+    check_close(F.squeeze(h1, 0), F.max(h0, 0))
     h1 = sort_pool(g, h0)
-    assert h1.shape[0] == 10 * 5 and h1.ndim == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
     # test#2: batched graph
     g_ = dgl.DGLGraph(nx.path_graph(5))
......
@@ -124,7 +124,7 @@ def test_set2set():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = s2s(g, h0)
-    assert h1.shape[0] == 10 and h1.dim() == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
     # test#2: batched graph
     g1 = dgl.DGLGraph(nx.path_graph(11))
@@ -145,7 +145,7 @@ def test_glob_att_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(g, h0)
-    assert h1.shape[0] == 10 and h1.dim() == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
@@ -170,13 +170,13 @@ def test_simple_pool():
     max_pool = max_pool.to(ctx)
     sort_pool = sort_pool.to(ctx)
     h1 = sum_pool(g, h0)
-    assert F.allclose(h1, F.sum(h0, 0))
+    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
     h1 = avg_pool(g, h0)
-    assert F.allclose(h1, F.mean(h0, 0))
+    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
     h1 = max_pool(g, h0)
-    assert F.allclose(h1, F.max(h0, 0))
+    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
     h1 = sort_pool(g, h0)
-    assert h1.shape[0] == 10 * 5 and h1.dim() == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2
     # test#2: batched graph
     g_ = dgl.DGLGraph(nx.path_graph(5))
@@ -228,7 +228,7 @@ def test_set_trans():
     h1 = st_enc_1(g, h0)
     assert h1.shape == h0.shape
     h2 = st_dec(g, h1)
-    assert h2.shape[0] == 200 and h2.dim() == 1
+    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2
     # test#2: batched graph
     g1 = dgl.DGLGraph(nx.path_graph(5))
......
@@ -93,13 +93,13 @@ def test_simple_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = sum_pool(g, h0)
-    assert F.allclose(h1, F.sum(h0, 0))
+    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
     h1 = avg_pool(g, h0)
-    assert F.allclose(h1, F.mean(h0, 0))
+    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
     h1 = max_pool(g, h0)
-    assert F.allclose(h1, F.max(h0, 0))
+    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
     h1 = sort_pool(g, h0)
-    assert h1.shape[0] == 10 * 5 and h1.ndim == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
     # test#2: batched graph
     g_ = dgl.DGLGraph(nx.path_graph(5))
@@ -246,7 +246,7 @@ def test_glob_att_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(g, h0)
-    assert h1.shape[0] == 10 and h1.ndim == 1
+    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
......
@@ -91,12 +91,13 @@ def plot_tree(g):
 plot_tree(graph.to_networkx())
 #################################################################################
-# You can read more about the definition of :func:`~dgl.batched_graph.batch`, or
+# You can read more about the definition of :func:`~dgl.batch`, or
 # skip ahead to the next step:
 # .. note::
 #
-#    **Definition**: A :class:`~dgl.batched_graph.BatchedDGLGraph` is a
-#    :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
+#    **Definition**: :func:`~dgl.batch` unions a list of :math:`B`
+#    :class:`~dgl.DGLGraph`\ s and returns a :class:`~dgl.DGLGraph` of batch
+#    size :math:`B`.
 #
 #    - The union includes all the nodes,
 #      edges, and their features. The order of nodes, edges, and features are
@@ -108,23 +109,16 @@ plot_tree(graph.to_networkx())
 #      :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
 #
 #    - Therefore, performing feature transformation and message passing on
-#      ``BatchedDGLGraph`` is equivalent to doing those
+#      the batched graph is equivalent to doing those
 #      on all ``DGLGraph`` constituents in parallel.
 #
 #    - Duplicate references to the same graph are
 #      treated as deep copies; the nodes, edges, and features are duplicated,
 #      and mutation on one reference does not affect the other.
-#    - Currently, ``BatchedDGLGraph`` is immutable in
-#      graph structure. You can't add
-#      nodes and edges to it. You need to support mutable batched graphs in
-#      (far) future.
-#    - The ``BatchedDGLGraph`` keeps track of the meta
+#    - The batched graph keeps track of the meta
 #      information of the constituents so it can be
 #      :func:`~dgl.batched_graph.unbatch`\ ed to list of ``DGLGraph``\ s.
 #
-# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
-# module in DGL, you can click the class name.
-#
 # Step 2: Tree-LSTM cell with message-passing APIs
 # ------------------------------------------------
 #
......
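To make the tutorial note above concrete, here is a small hedged sketch of what dgl.batch and dgl.unbatch are expected to do after this change (graph sizes are arbitrary):

    import dgl
    import networkx as nx

    g1 = dgl.DGLGraph(nx.path_graph(3))
    g2 = dgl.DGLGraph(nx.path_graph(4))
    bg = dgl.batch([g1, g2])

    print(bg.batch_size)          # 2
    print(bg.batch_num_nodes)     # [3, 4]
    print(bg.number_of_nodes())   # 7; nodes of g2 are relabeled after those of g1

    print([g.number_of_nodes() for g in dgl.unbatch(bg)])   # [3, 4]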
@@ -798,9 +798,8 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
 # <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__
 # , and it is worth explaining one more time why this is so.
 #
-# By batching many small graphs, DGL internally maintains a large *container*
-# graph (``BatchedDGLGraph``) over which ``update_all`` propels message-passing
-# on all the edges and nodes.
+# By batching many small graphs, DGL parallels message passing on each individual
+# graphs of a batch.
 #
 # With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single giant
 # graph consisting of :math:`N` isolated small graphs. For example, if we
@@ -833,7 +832,7 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
 # internally groups nodes with the same in-degrees and calls reduce UDF once
 # for each group. Thus, batching also reduces number of calls to these UDFs.
 #
-# The modification of the node/edge features of a ``BatchedDGLGraph`` object
+# The modification of the node/edge features of the batched graph object
 # does not take effect on the features of the original small graphs, so we
 # need to replace the old graph list with the new graph list
 # ``g_list = dgl.unbatch(bg)``.
......
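The caveat in that last paragraph is easy to trip over, so a minimal hedged illustration (the feature name 'h' and the sizes are made up): feature writes on the batched graph stay on the batched copy until you unbatch.

    import dgl
    import networkx as nx
    import torch

    g_list = [dgl.DGLGraph(nx.path_graph(3)) for _ in range(2)]
    for g in g_list:
        g.ndata['h'] = torch.zeros(g.number_of_nodes(), 1)

    bg = dgl.batch(g_list)
    bg.ndata['h'] = bg.ndata['h'] + 1.0   # updates only the batched copy

    print(g_list[0].ndata['h'].sum())     # still 0
    g_list = dgl.unbatch(bg)              # replace the list to pick up the new values
    print(g_list[0].ndata['h'].sum())     # now 3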