"tests/vscode:/vscode.git/clone" did not exist on "0d9b6bfd5a1ac04ec9727c14fc662b4942700b24"
Unverified Commit 93ac29ce authored by Zihao Ye, committed by GitHub

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent 8874e830
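For orientation, here is a minimal usage sketch of the behavior this refactor aims for (assuming a dgl build from this branch, not the released package): dgl.batch now returns a plain DGLGraph, a standalone DGLGraph behaves as a batch of size one, and readout functions always return a stacked (B, *) tensor.

import dgl
import torch as th

g1 = dgl.DGLGraph()
g1.add_nodes(2)
g1.add_edge(0, 1)
g1.ndata['hv'] = th.tensor([[0.], [1.]])

g2 = dgl.DGLGraph()
g2.add_nodes(3)
g2.add_edges([0, 2], [1, 1])
g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]])

bg = dgl.batch([g1, g2])         # a regular DGLGraph, no separate BatchedDGLGraph class
print(bg.batch_size)             # 2
print(dgl.sum_nodes(bg, 'hv'))   # tensor([[1.], [9.]])
print(dgl.sum_nodes(g1, 'hv'))   # tensor([[1.]]) -- a single graph is a batch of size one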
......@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .gnn import GCNLayer, GATLayer
from ...batched_graph import BatchedDGLGraph, max_nodes
from ...readout import max_nodes
from ...nn.pytorch import WeightAndSum
from ...contrib.deprecation import deprecated
......@@ -74,13 +74,13 @@ class BaseGNNClassifier(nn.Module):
self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, bg, feats):
def forward(self, g, feats):
"""Multi-task prediction for a batch of molecules
Parameters
----------
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules
......@@ -91,18 +91,15 @@ class BaseGNNClassifier(nn.Module):
"""
# Update atom features with GNNs
for gnn in self.gnn_layers:
feats = gnn(bg, feats)
feats = gnn(g, feats)
# Compute molecule features from atom features
h_g_sum = self.weighted_sum_readout(bg, feats)
h_g_sum = self.weighted_sum_readout(g, feats)
with bg.local_scope():
bg.ndata['h'] = feats
h_g_max = max_nodes(bg, 'h')
with g.local_scope():
g.ndata['h'] = feats
h_g_max = max_nodes(g, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction
......
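As a hedged illustration of the updated forward signature, a sketch of calling such a classifier on a batch of molecule graphs; the graph construction and the atom feature size of 74 are hypothetical and not taken from this diff.

import dgl
import torch

mols = [dgl.DGLGraph(), dgl.DGLGraph()]   # hypothetical featurized molecules
mols[0].add_nodes(3)
mols[1].add_nodes(4)
for m in mols:
    m.ndata['h'] = torch.randn(m.number_of_nodes(), 74)

bg = dgl.batch(mols)              # DGLGraph with batch_size == 2
feats = bg.ndata['h']             # (N, 74), N = total number of atoms
# preds = model(bg, feats)        # model: a BaseGNNClassifier subclass
# preds.shape == (2, n_tasks)     # one row of task predictions per molecule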
......@@ -42,13 +42,13 @@ class GCNLayer(nn.Module):
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, bg, feats):
def forward(self, g, feats):
"""Update atom representations
Parameters
----------
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
......@@ -58,7 +58,7 @@ class GCNLayer(nn.Module):
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization
"""
new_feats = self.graph_conv(bg, feats)
new_feats = self.graph_conv(g, feats)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
......
......@@ -7,7 +7,7 @@ import torch.nn.functional as F
import rdkit.Chem as Chem
from ....batched_graph import batch, unbatch
from ....graph import batch, unbatch
from ....contrib.deprecation import deprecated
from ....data.utils import get_download_dir
......
......@@ -63,7 +63,7 @@ class ChebConv(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
......@@ -3,8 +3,7 @@
from mxnet import gluon, nd
from mxnet.gluon import nn
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
......@@ -24,7 +23,7 @@ class SumPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
......@@ -33,9 +32,8 @@ class SumPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -61,7 +59,7 @@ class AvgPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
......@@ -70,9 +68,8 @@ class AvgPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -98,7 +95,7 @@ class MaxPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
......@@ -107,9 +104,8 @@ class MaxPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -139,7 +135,7 @@ class SortPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
......@@ -148,9 +144,8 @@ class SortPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`.
The output feature with shape :math:`(B, k * D)`, where
:math:`B` refers to the batch size.
"""
# Sort the feature of each node in ascending order.
with graph.local_scope():
......@@ -159,10 +154,7 @@ class SortPooling(nn.Block):
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k)[0].reshape(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(axis=0)
return ret
def __repr__(self):
return 'SortPooling(k={})'.format(self.k)
......@@ -195,7 +187,7 @@ class GlobalAttentionPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
......@@ -204,9 +196,8 @@ class GlobalAttentionPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
gate = self.gate_nn(feat)
......@@ -263,7 +254,7 @@ class Set2Set(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
......@@ -272,14 +263,11 @@ class Set2Set(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
......@@ -288,23 +276,14 @@ class Set2Set(nn.Block):
for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.ndim == 1: # graph is not a BatchedDGLGraph
readout = readout.expand_dims(0)
q_star = nd.concat(q, readout, dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(axis=0)
return q_star
def __repr__(self):
summary = 'Set2Set('
......
......@@ -203,7 +203,7 @@ class AtomicConv(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
Topology based on which message passing is performed.
feat : Float32 tensor of shape (V, 1)
Initial node features, which are atomic numbers in the paper.
......
......@@ -68,7 +68,7 @@ class ChebConv(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
......@@ -4,9 +4,8 @@ import torch as th
import torch.nn as nn
import numpy as np
from ... import BatchedDGLGraph
from ...backend import pytorch as F
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
......@@ -29,7 +28,7 @@ class SumPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -38,9 +37,8 @@ class SumPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -62,7 +60,7 @@ class AvgPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -71,9 +69,8 @@ class AvgPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -95,7 +92,7 @@ class MaxPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -104,9 +101,8 @@ class MaxPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -132,7 +128,7 @@ class SortPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -141,9 +137,8 @@ class SortPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`.
The output feature with shape :math:`(B, k * D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
# Sort the feature of each node in ascending order.
......@@ -152,10 +147,7 @@ class SortPooling(nn.Module):
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].view(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(0)
return ret
class GlobalAttentionPooling(nn.Module):
......@@ -193,9 +185,8 @@ class GlobalAttentionPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
gate = self.gate_nn(feat)
......@@ -257,7 +248,7 @@ class Set2Set(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -266,14 +257,11 @@ class Set2Set(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
batch_size = graph.batch_size
h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
feat.new_zeros((self.n_layers, batch_size, self.input_dim)))
......@@ -283,23 +271,14 @@ class Set2Set(nn.Module):
for _ in range(self.n_iters):
q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.input_dim)
e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.dim() == 1: # graph is not a BatchedDGLGraph
readout = readout.unsqueeze(0)
q_star = th.cat([q, readout], dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(0)
return q_star
def extra_repr(self):
"""Set the extra representation of the module.
......@@ -574,7 +553,7 @@ class SetTransformerEncoder(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -585,14 +564,9 @@ class SetTransformerEncoder(nn.Module):
torch.Tensor
The output feature with shape :math:`(N, D)`.
"""
if isinstance(graph, BatchedDGLGraph):
lengths = graph.batch_num_nodes
else:
lengths = [graph.number_of_nodes()]
lengths = graph.batch_num_nodes
for layer in self.layers:
feat = layer(feat, lengths)
return feat
......@@ -640,7 +614,7 @@ class SetTransformerDecoder(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -649,25 +623,15 @@ class SetTransformerDecoder(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
if isinstance(graph, BatchedDGLGraph):
len_pma = graph.batch_num_nodes
len_sab = [self.k] * graph.batch_size
else:
len_pma = [graph.number_of_nodes()]
len_sab = [self.k]
len_pma = graph.batch_num_nodes
len_sab = [self.k] * graph.batch_size
feat = self.pma(feat, len_pma)
for layer in self.layers:
feat = layer(feat, len_sab)
if isinstance(graph, BatchedDGLGraph):
return feat.view(graph.batch_size, self.k * self.d_model)
else:
return feat.view(self.k * self.d_model)
return feat.view(graph.batch_size, self.k * self.d_model)
class WeightAndSum(nn.Module):
......@@ -686,13 +650,13 @@ class WeightAndSum(nn.Module):
nn.Sigmoid()
)
def forward(self, bg, feats):
def forward(self, g, feats):
"""Compute molecule representations out of atom representations
Parameters
----------
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules
* N is the total number of atoms in all molecules
......@@ -702,9 +666,9 @@ class WeightAndSum(nn.Module):
FloatTensor of shape (B, self.in_feats)
Representations for B molecules
"""
with bg.local_scope():
bg.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w')
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')
return h_g_sum
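A short usage sketch of the updated PyTorch pooling modules, under the assumption that they remain exposed from dgl.nn.pytorch; the single-graph output shapes below mirror the updated unit tests at the end of this diff.

import dgl
import torch as th
from dgl.nn.pytorch import AvgPooling, Set2Set

g = dgl.DGLGraph()
g.add_nodes(10)
g.add_edges(list(range(9)), list(range(1, 10)))   # a simple path graph
feat = th.randn(g.number_of_nodes(), 5)

avgpool = AvgPooling()
print(avgpool(g, feat).shape)    # torch.Size([1, 5])  -- (B, *) with B == 1

s2s = Set2Set(5, 3, 1)           # input_dim=5, n_iters=3, n_layers=1
print(s2s(g, feat).shape)        # torch.Size([1, 10]) -- (B, 2 * input_dim)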
......@@ -4,8 +4,7 @@ import tensorflow as tf
from tensorflow.keras import layers
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, \
from ...readout import sum_nodes, mean_nodes, max_nodes, \
softmax_nodes, topk_nodes
......@@ -29,7 +28,7 @@ class SumPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -38,9 +37,8 @@ class SumPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -63,7 +61,7 @@ class AvgPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -72,9 +70,8 @@ class AvgPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -97,7 +94,7 @@ class MaxPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -106,9 +103,8 @@ class MaxPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -135,7 +131,7 @@ class SortPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -144,9 +140,8 @@ class SortPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`.
The output feature with shape :math:`(B, k * D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
# Sort the feature of each node in ascending order.
......@@ -155,10 +150,7 @@ class SortPooling(layers.Layer):
# Sort nodes according to their last features.
ret = tf.reshape(topk_nodes(graph, 'h', self.k, idx=-1)[0], (
-1, self.k * feat.shape[-1]))
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return tf.squeeze(ret, 0)
return ret
class GlobalAttentionPooling(layers.Layer):
......@@ -197,9 +189,8 @@ class GlobalAttentionPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
gate = self.gate_nn(feat)
......@@ -234,13 +225,13 @@ class WeightAndSum(layers.Layer):
layers.Activation(tf.nn.sigmoid)
)
def call(self, bg, feats):
def call(self, g, feats):
"""Compute molecule representations out of atom representations
Parameters
----------
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules
* N is the total number of atoms in all molecules
......@@ -250,9 +241,9 @@ class WeightAndSum(layers.Layer):
FloatTensor of shape (B, self.in_feats)
Representations for B molecules
"""
with bg.local_scope():
bg.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w')
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')
return h_g_sum
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
from collections.abc import Iterable
import numpy as np
from .base import ALL, is_all, DGLError
from .frame import FrameRef, Frame
from .graph import DGLGraph
from . import graph_index as gi
from .base import DGLError
from . import backend as F
from . import utils
__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
__all__ = ['sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
'max_nodes', 'max_edges', 'softmax_nodes', 'softmax_edges',
'broadcast_nodes', 'broadcast_edges', 'topk_nodes', 'topk_edges']
class BatchedDGLGraph(DGLGraph):
"""Class for batched DGL graphs.
A :class:`BatchedDGLGraph` basically merges a list of small graphs into a giant
graph so that one can perform message passing and readout over a batch of graphs
simultaneously.
The nodes and edges are re-indexed with a new id in the batched graph with the
rule below:
====== ========== ======================== === ==========================
item Graph 1 Graph 2 ... Graph k
====== ========== ======================== === ==========================
raw id 0, ..., N1 0, ..., N2 ... ..., Nk
new id 0, ..., N1 N1 + 1, ..., N1 + N2 + 1 ... ..., N1 + ... + Nk + k - 1
====== ========== ======================== === ==========================
The batched graph is read-only, i.e. one cannot further add nodes and edges.
A ``RuntimeError`` will be raised if one attempts.
To modify the features in :class:`BatchedDGLGraph` has no effect on the original
graphs. See the examples below about how to work around.
Parameters
----------
graph_list : iterable
A collection of :class:`~dgl.DGLGraph` objects to be batched.
node_attrs : None, str or iterable, optional
The node attributes to be batched. If ``None``, the :class:`BatchedDGLGraph` object
will not have any node attributes. By default, all node attributes will be batched.
An error will be raised if graphs having nodes have different attributes. If ``str``
or ``iterable``, this should specify exactly what node attributes to be batched.
edge_attrs : None, str or iterable, optional
Same as for the case of :attr:`node_attrs`
Examples
--------
Create two :class:`~dgl.DGLGraph` objects.
**Instantiation:**
>>> import dgl
>>> import torch as th
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(2) # Add 2 nodes
>>> g1.add_edge(0, 1) # Add edge 0 -> 1
>>> g1.ndata['hv'] = th.tensor([[0.], [1.]]) # Initialize node features
>>> g1.edata['he'] = th.tensor([[0.]]) # Initialize edge features
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3) # Add 3 nodes
>>> g2.add_edges([0, 2], [1, 1]) # Add edges 0 -> 1, 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]]) # Initialize edge features
Merge two :class:`~dgl.DGLGraph` objects into one :class:`BatchedDGLGraph` object.
When merging a list of graphs, we can choose to include only a subset of the attributes.
>>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata
{}
Below one can see that the nodes are re-indexed. The edges are re-indexed in
the same way.
>>> bg.nodes()
tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv']
tensor([[0.],
[1.],
[2.],
[3.],
[4.]])
**Property:**
We can still get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size
2
>>> bg.batch_num_nodes
[2, 3]
>>> bg.batch_num_edges
[1, 2]
**Readout:**
Another common demand for graph neural networks is graph readout, which is a
function that takes in the node attributes and/or edge attributes for a graph
and outputs a vector summarizing the information in the graph.
:class:`BatchedDGLGraph` also supports performing readout for a batch of graphs at once.
Below we take the built-in readout function :func:`sum_nodes` as an example, which
sums over a particular kind of node attribute for each graph.
>>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
tensor([[1.], # 0 + 1
[9.]]) # 2 + 3 + 4
**Message passing:**
For message passing and related operations, :class:`BatchedDGLGraph` acts exactly
the same as :class:`~dgl.DGLGraph`.
**Update Attributes:**
Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he']
tensor([[1.],
[2.]])}
Instead, we can decompose the batched graph back into a list of graphs and use them
to replace the original graphs.
>>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraph objects
>>> g2.edata['he']
tensor([[0., 0.],
[0., 0.]])}
"""
def __init__(self, graph_list, node_attrs, edge_attrs):
def _get_num_item_and_attr_types(g, mode):
if mode == 'node':
num_items = g.number_of_nodes()
attr_types = set(g.node_attr_schemes().keys())
elif mode == 'edge':
num_items = g.number_of_edges()
attr_types = set(g.edge_attr_schemes().keys())
return num_items, attr_types
def _init_attrs(attrs, mode):
if attrs is None:
return []
elif is_all(attrs):
attrs = set()
# Check if at least a graph has mode items and associated features.
for i, g in enumerate(graph_list):
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_num_items > 0 and len(g_attrs) > 0:
attrs = g_attrs
ref_g_index = i
break
# Check if all the graphs with mode items have the same associated features.
if len(attrs) > 0:
for i, g in enumerate(graph_list):
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_attrs != attrs and g_num_items > 0:
raise ValueError('Expect graph {0} and {1} to have the same {2} '
'attributes when {2}_attrs=ALL, got {3} and {4}.'
.format(ref_g_index, i, mode, attrs, g_attrs))
return attrs
elif isinstance(attrs, str):
return [attrs]
elif isinstance(attrs, Iterable):
return attrs
else:
raise ValueError('Expected {} attrs to be of type None str or Iterable, '
'got type {}'.format(mode, type(attrs)))
node_attrs = _init_attrs(node_attrs, 'node')
edge_attrs = _init_attrs(edge_attrs, 'edge')
# create batched graph index
batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames
if len(node_attrs) == 0:
batched_node_frame = FrameRef(Frame(num_rows=batched_index.number_of_nodes()))
else:
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
if gr.number_of_nodes() > 0], dim=0)
for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols))
if len(edge_attrs) == 0:
batched_edge_frame = FrameRef(Frame(num_rows=batched_index.number_of_edges()))
else:
cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
if gr.number_of_edges() > 0], dim=0)
for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
super(BatchedDGLGraph, self).__init__(graph_data=batched_index,
node_frame=batched_node_frame,
edge_frame=batched_edge_frame)
# extra members
self._batch_size = 0
self._batch_num_nodes = []
self._batch_num_edges = []
for grh in graph_list:
if isinstance(grh, BatchedDGLGraph):
# handle the input is again a batched graph.
self._batch_size += grh._batch_size
self._batch_num_nodes += grh._batch_num_nodes
self._batch_num_edges += grh._batch_num_edges
else:
self._batch_size += 1
self._batch_num_nodes.append(grh.number_of_nodes())
self._batch_num_edges.append(grh.number_of_edges())
@property
def batch_size(self):
"""Number of graphs in this batch.
Returns
-------
int
Number of graphs in this batch."""
return self._batch_size
@property
def batch_num_nodes(self):
"""Number of nodes of each graph in this batch.
Returns
-------
list
Number of nodes of each graph in this batch."""
return self._batch_num_nodes
@property
def batch_num_edges(self):
"""Number of edges of each graph in this batch.
Returns
-------
list
Number of edges of each graph in this batch."""
return self._batch_num_edges
# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
# new APIs
def __getitem__(self, idx):
"""Slice the batch and return the batch of graphs specified by the idx."""
# TODO
raise NotImplementedError
def __setitem__(self, idx, val):
"""Set the value of the slice. The graph size cannot be changed."""
# TODO
raise NotImplementedError
def split(graph_batch, num_or_size_splits): # pylint: disable=unused-argument
"""Split the batch."""
# TODO(minjie): could follow torch.split syntax
raise NotImplementedError
def unbatch(graph):
"""Return the list of graphs in this batch.
Parameters
----------
graph : BatchedDGLGraph
The batched graph.
Returns
-------
list
A list of :class:`~dgl.DGLGraph` objects whose attributes are obtained
by partitioning the attributes of the :attr:`graph`. The length of the
list is the same as the batch size of :attr:`graph`.
Notes
-----
Unbatching will break each field tensor of the batched graph into smaller
partitions.
For simpler tasks such as node/edge state aggregation, try to use
readout functions.
See Also
--------
batch
"""
assert isinstance(graph, BatchedDGLGraph)
bsize = graph.batch_size
bnn = graph.batch_num_nodes
bne = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bnn))
# split the frames
node_frames = [FrameRef(Frame(num_rows=n)) for n in bnn]
edge_frames = [FrameRef(Frame(num_rows=n)) for n in bne]
for attr, col in graph._node_frame.items():
col_splits = F.split(col, bnn, dim=0)
for i in range(bsize):
node_frames[i][attr] = col_splits[i]
for attr, col in graph._edge_frame.items():
col_splits = F.split(col, bne, dim=0)
for i in range(bsize):
edge_frames[i][attr] = col_splits[i]
return [DGLGraph(graph_data=pttns[i],
node_frame=node_frames[i],
edge_frame=edge_frames[i]) for i in range(bsize)]
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
"""Batch a collection of :class:`~dgl.DGLGraph` and return a
:class:`BatchedDGLGraph` object that is independent of the :attr:`graph_list`.
Parameters
----------
graph_list : iterable
A collection of :class:`~dgl.DGLGraph` to be batched.
node_attrs : None, str or iterable
The node attributes to be batched. If ``None``, the :class:`BatchedDGLGraph`
object will not have any node attributes. By default, all node attributes will
be batched. If ``str`` or iterable, this should specify exactly what node
attributes to be batched.
edge_attrs : None, str or iterable, optional
Same as for the case of :attr:`node_attrs`
Returns
-------
BatchedDGLGraph
One single batched graph
See Also
--------
BatchedDGLGraph
unbatch
"""
return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
READOUT_ON_ATTRS = {
'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
......@@ -387,15 +43,12 @@ def _sum_on(graph, typestr, feat, weight):
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
feat = weight * feat
if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
return y
else:
return F.sum(feat, 0)
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
return y
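The unified `_sum_on` above always takes the segment-sum path, using `batch_num_nodes`/`batch_num_edges` to map each packed row to its graph. A plain-PyTorch sketch of the same idea, for illustration only (the helper name is hypothetical):

import torch

def segment_sum_readout(feat, batch_num_nodes):
    # feat: (N, D) packed node features; batch_num_nodes: list of node counts.
    n_graphs = len(batch_num_nodes)
    # Graph id of each node, e.g. [0, 0, 1, 1, 1] for counts [2, 3].
    seg_id = torch.repeat_interleave(
        torch.arange(n_graphs), torch.tensor(batch_num_nodes))
    out = feat.new_zeros((n_graphs, feat.shape[1]))
    out.index_add_(0, seg_id, feat)   # accumulate rows into their graph's slot
    return out                        # shape (B, D)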
def sum_nodes(graph, feat, weight=None):
"""Sums all the values of node field :attr:`feat` in :attr:`graph`, optionally
......@@ -420,10 +73,10 @@ def sum_nodes(graph, feat, weight=None):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of the
corresponding example in the batch. If an example has no nodes,
Return a stacked tensor with an extra first dimension whose size equals
the batch size of the input graph.
The i-th row of the stacked tensor contains the readout result of the
i-th graph in the batched graph. If a graph has no nodes,
a zero tensor with the same shape is returned at the corresponding row.
Examples
......@@ -456,7 +109,7 @@ def sum_nodes(graph, feat, weight=None):
for a single graph.
>>> dgl.sum_nodes(g1, 'h', 'w')
tensor([15.]) # 1 * 3 + 2 * 6
tensor([[15.]]) # 1 * 3 + 2 * 6
See Also
--------
......@@ -489,10 +142,10 @@ def sum_edges(graph, feat, weight=None):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of the
corresponding example in the batch. If an example has no edges,
Return a stacked tensor with an extra first dimension whose size equals
the batch size of the input graph.
The i-th row of the stacked tensor contains the readout result of the
i-th graph in the batched graph. If a graph has no edges,
a zero tensor with the same shape is returned at the corresponding row.
Examples
......@@ -527,7 +180,7 @@ def sum_edges(graph, feat, weight=None):
for a single graph.
>>> dgl.sum_edges(g1, 'h', 'w')
tensor([15.]) # 1 * 3 + 2 * 6
tensor([[15.]]) # 1 * 3 + 2 * 6
See Also
--------
......@@ -566,24 +219,17 @@ def _mean_on(graph, typestr, feat, weight):
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
feat = weight * feat
if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
if weight is not None:
w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = y / w
else:
y = F.unsorted_1d_segment_mean(feat, seg_id, n_graphs, 0)
return y
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
if weight is not None:
w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = y / w
else:
if weight is None:
return F.mean(feat, 0)
else:
y = F.sum(feat, 0) / F.sum(weight, 0)
return y
y = F.unsorted_1d_segment_mean(feat, seg_id, n_graphs, 0)
return y
def mean_nodes(graph, feat, weight=None):
"""Averages all the values of node field :attr:`feat` in :attr:`graph`,
......@@ -591,7 +237,7 @@ def mean_nodes(graph, feat, weight=None):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -608,10 +254,10 @@ def mean_nodes(graph, feat, weight=None):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no nodes,
Return a stacked tensor with an extra first dimension whose size equals
the batch size of the input graph.
The i-th row of the stacked tensor contains the readout result of
the i-th graph in the batch. If a graph has no nodes,
a zero tensor with the same shape is returned at the corresponding row.
Examples
......@@ -644,7 +290,7 @@ def mean_nodes(graph, feat, weight=None):
for a single graph.
>>> dgl.mean_nodes(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2))
tensor([1.6667]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))
tensor([[1.6667]]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))
See Also
--------
......@@ -677,10 +323,10 @@ def mean_edges(graph, feat, weight=None):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no edges,
Return a stacked tensor with an extra first dimension whose size equals
the batch size of the input graph.
The i-th row of the stacked tensor contains the readout result of
the i-th graph in the batched graph. If a graph has no edges,
a zero tensor with the same shape is returned at the corresponding row.
Examples
......@@ -715,7 +361,7 @@ def mean_edges(graph, feat, weight=None):
for a single graph.
>>> dgl.mean_edges(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2))
tensor([1.6667]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))
tensor([[1.6667]]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))
See Also
--------
......@@ -750,12 +396,10 @@ def _max_on(graph, typestr, feat):
# TODO: the current solution pads the different graph sizes to the same,
# a more efficient way is to use segment max, we need to implement it in
# the future.
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
return F.max(feat, 1)
else:
return F.max(feat, 0)
batch_num_objs = getattr(graph, batch_num_objs_attr)
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
return F.max(feat, 1)
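As the TODO notes, `_max_on` pads the per-graph chunks to a common length before reducing. A rough plain-PyTorch equivalent, for illustration only (hypothetical helper name):

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_max_readout(feat, batch_num_nodes):
    # Split the packed (N, D) tensor per graph, pad with -inf, max over nodes.
    chunks = torch.split(feat, batch_num_nodes, dim=0)
    padded = pad_sequence(chunks, batch_first=True, padding_value=float('-inf'))
    return padded.max(dim=1).values   # shape (B, D)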
def _softmax_on(graph, typestr, feat):
"""Internal function of applying batch-wise graph-level softmax
......@@ -782,13 +426,10 @@ def _softmax_on(graph, typestr, feat):
# TODO: the current solution pads the different graph sizes to the same,
# a more efficient way is to use segment sum/max, we need to implement
# it in the future.
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
feat = F.softmax(feat, 1)
return F.pack_padded_tensor(feat, batch_num_objs)
else:
return F.softmax(feat, 0)
batch_num_objs = getattr(graph, batch_num_objs_attr)
feat = F.pad_packed_tensor(feat, batch_num_objs, -float('inf'))
feat = F.softmax(feat, 1)
return F.pack_padded_tensor(feat, batch_num_objs)
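`_softmax_on` uses the same padding trick and then packs the result back to the original layout; a hypothetical plain-PyTorch counterpart:

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_softmax(feat, batch_num_nodes):
    chunks = torch.split(feat, batch_num_nodes, dim=0)
    padded = pad_sequence(chunks, batch_first=True, padding_value=float('-inf'))
    probs = torch.softmax(padded, dim=1)   # softmax within each graph
    # Drop the padding rows and restore the packed (N, D) layout.
    return torch.cat([probs[i, :n] for i, n in enumerate(batch_num_nodes)], dim=0)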
def _broadcast_on(graph, typestr, feat_data):
"""Internal function of broadcasting features to all nodes/edges.
......@@ -808,21 +449,15 @@ def _broadcast_on(graph, typestr, feat_data):
tensor
The node/edge features tensor with shape :math:`(N, *)`.
"""
_, batch_num_objs_attr, num_objs_attr = READOUT_ON_ATTRS[typestr]
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
index = []
for i, num_obj in enumerate(batch_num_objs):
index.extend([i] * num_obj)
ctx = F.context(feat_data)
index = F.copy_to(F.tensor(index), ctx)
return F.gather_row(feat_data, index)
else:
num_objs = getattr(graph, num_objs_attr)()
if F.ndim(feat_data) == 1:
feat_data = F.unsqueeze(feat_data, 0)
return F.cat([feat_data] * num_objs, 0)
_, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
batch_num_objs = getattr(graph, batch_num_objs_attr)
index = []
for i, num_obj in enumerate(batch_num_objs):
index.extend([i] * num_obj)
ctx = F.context(feat_data)
index = F.copy_to(F.tensor(index), ctx)
return F.gather_row(feat_data, index)
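`_broadcast_on` now always gathers row i of the graph-level feature for every node or edge of graph i; a minimal sketch of that indexing (hypothetical helper name):

import torch

def broadcast_to_nodes(feat_data, batch_num_nodes):
    # feat_data: (B, D) graph-level features; result: (N, D) aligned with ndata.
    index = torch.repeat_interleave(
        torch.arange(len(batch_num_nodes)), torch.tensor(batch_num_nodes))
    return feat_data[index]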
def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
"""Internal function to take graph-wise top-k node/edge features of
......@@ -832,7 +467,7 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
If idx is set to None, the function would return top-k value of all
indices, which is equivalent to calling `th.topk(graph.ndata[feat], dim=0)`
for each example of the input graph.
for each single graph of the input batched graph.
Parameters
---------
......@@ -854,14 +489,15 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
Returns
-------
tuple of tensors:
The first tensor returns top-k features of the given graph with
shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
The first tensor returns top-k features of each single graph of
the input graph:
a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size.
The second tensor returns the top-k indices of the given graph
with shape :math:`(K)`, if the input graph is a BatchedDGLGraph,
a tensor with shape :math:`(B, K)` would be returned, where
:math:`B` is the batch size.
:math:`B` is the batch size of the input graph.
The second tensor returns the top-k indices of each single graph
of the input graph:
a tensor with shape :math:`(B, K)` (:math:`(B, K, D)` if idx
is set to None) would be returned, where
:math:`B` is the batch size of the input graph.
Notes
-----
......@@ -870,7 +506,7 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
with all zero; in the second returned tensor, the behavior of :math:`n+1`
to :math:`k`th elements is not defined.
"""
data_attr, batch_num_objs_attr, num_objs_attr = READOUT_ON_ATTRS[typestr]
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
if F.ndim(data[feat]) > 2:
raise DGLError('The {} feature `{}` should have dimension less than or'
......@@ -878,12 +514,8 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
feat = data[feat]
hidden_size = F.shape(feat)[-1]
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_size = len(batch_num_objs)
else:
batch_num_objs = [getattr(graph, num_objs_attr)()]
batch_size = 1
batch_num_objs = getattr(graph, batch_num_objs_attr)
batch_size = len(batch_num_objs)
length = max(max(batch_num_objs), k)
fill_val = -float('inf') if descending else float('inf')
......@@ -912,12 +544,8 @@ def _topk_on(graph, typestr, feat, k, descending=True, idx=None):
shift = F.copy_to(shift, F.context(feat))
topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
if isinstance(graph, BatchedDGLGraph):
return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\
topk_indices
else:
return F.reshape(F.gather_row(feat_, topk_indices_), (k, -1)),\
topk_indices
return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\
topk_indices
def max_nodes(graph, feat):
......@@ -926,7 +554,7 @@ def max_nodes(graph, feat):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -963,14 +591,14 @@ def max_nodes(graph, feat):
Max over node attribute :attr:`h` in a single graph.
>>> dgl.max_nodes(g1, 'h')
tensor([2.])
tensor([[2.]])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no nodes,
Return a stacked tensor with an extra first dimension whose size equals
the batch size of the input graph.
The i-th row of the stacked tensor contains the readout result of
the i-th graph in the batched graph. If a graph has no nodes,
a tensor filled with -inf of the same shape is returned at the
corresponding row.
"""
......@@ -982,7 +610,7 @@ def max_edges(graph, feat):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -1021,14 +649,14 @@ def max_edges(graph, feat):
Max over edge attribute :attr:`h` in a single graph.
>>> dgl.max_edges(g1, 'h')
tensor([2.])
tensor([[2.]])
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no edges,
Return a stacked tensor with an extra first dimension whose size equals
the batch size of the input graph.
The i-th row of the stacked tensor contains the readout result of
the i-th graph in the batched graph. If a graph has no edges,
a tensor filled with -inf of the same shape is returned at the
corresponding row.
"""
......@@ -1040,7 +668,7 @@ def softmax_nodes(graph, feat):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -1085,8 +713,8 @@ def softmax_nodes(graph, feat):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, the softmax is applied at
each example in the batch.
If the input graph has batch size greater than one, the softmax is applied at
each single graph in the batched graph.
"""
return _softmax_on(graph, 'nodes', feat)
......@@ -1097,7 +725,7 @@ def softmax_edges(graph, feat):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -1144,7 +772,7 @@ def softmax_edges(graph, feat):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, the softmax is applied at each
If the input graph has batch size greater than one, the softmax is applied at each
example in the batch.
"""
return _softmax_on(graph, 'edges', feat)
......@@ -1155,7 +783,7 @@ def broadcast_nodes(graph, feat_data):
Parameters
----------
graph : DGLGraph or BatcheDGLGraph
graph : DGLGraph
The graph.
feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single graph, and
......@@ -1205,8 +833,7 @@ def broadcast_nodes(graph, feat_data):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, feat[i] is broadcast to the nodes
in i-th example in the batch.
feat[i] is broadcast to the nodes in the i-th graph in the batched graph.
"""
return _broadcast_on(graph, 'nodes', feat_data)
......@@ -1216,7 +843,7 @@ def broadcast_edges(graph, feat_data):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat_data : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single
......@@ -1268,8 +895,7 @@ def broadcast_edges(graph, feat_data):
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, feat[i] is broadcast to
the edges in i-th example in the batch.
feat[i] is broadcast to the edges in the i-th graph in the batched graph.
"""
return _broadcast_on(graph, 'edges', feat_data)
......@@ -1285,7 +911,7 @@ def topk_nodes(graph, feat, k, descending=True, idx=None):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -1300,15 +926,15 @@ def topk_nodes(graph, feat, k, descending=True, idx=None):
Returns
-------
tuple of tensors
The first tensor returns top-k node features of the given graph
with shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
The first tensor returns top-k node features of each single graph of
the input graph:
a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size.
The second tensor returns the top-k edge indices of the given
graph with shape :math:`(K)`(:math:`(K, D)` if idx is set to None),
if the input graph is a BatchedDGLGraph, a tensor with shape
:math:`(B, K)`(:math:`(B, K, D)` if` idx is set to None) would be
returned, where :math:`B` is the batch size.
:math:`B` is the batch size of the input graph.
The second tensor returns the top-k node indices of each single graph
of the input graph:
a tensor with shape :math:`(B, K)` (:math:`(B, K, D)` if idx
is set to None) would be returned, where
:math:`B` is the batch size of the input graph.
Examples
--------
......@@ -1372,9 +998,9 @@ def topk_nodes(graph, feat, k, descending=True, idx=None):
Top-k over node attribute :attr:`h` in a single graph.
>>> dgl.topk_nodes(g1, 'h', 3)
(tensor([[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]), tensor([[[1, 0, 1, 3, 1],
(tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]]]))
......@@ -1400,7 +1026,7 @@ def topk_edges(graph, feat, k, descending=True, idx=None):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : str
The feature field.
......@@ -1415,15 +1041,15 @@ def topk_edges(graph, feat, k, descending=True, idx=None):
Returns
-------
tuple of tensors
The first tensor returns top-k edge features of the given graph
with shape :math:`(K, D)`, if the input graph is a BatchedDGLGraph,
The first tensor returns top-k edge features of each single graph of
the input graph:
a tensor with shape :math:`(B, K, D)` would be returned, where
:math:`B` is the batch size.
The second tensor returns the top-k edge indices of the given
graph with shape :math:`(K)`(:math:`(K, D)` if idx is set to None),
if the input graph is a BatchedDGLGraph, a tensor with shape
:math:`(B, K)`(:math:`(B, K, D)` if` idx is set to None) would be
returned, where :math:`B` is the batch size.
:math:`B` is the batch size of the input graph.
The second tensor returns the top-k edge indices of each single graph
of the input graph:
a tensor with shape :math:`(B, K)` (:math:`(B, K, D)` if idx
is set to None) would be returned, where
:math:`B` is the batch size of the input graph.
Examples
--------
......@@ -1489,9 +1115,9 @@ def topk_edges(graph, feat, k, descending=True, idx=None):
Top-k over edge attribute :attr:`h` in a single graph.
>>> dgl.topk_edges(g1, 'h', 3)
(tensor([[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]), tensor([[[1, 0, 1, 3, 1],
(tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],
[0.5171, 0.6515, 0.9140, 0.7507, 0.5297],
[0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1],
[3, 2, 0, 2, 2],
[2, 3, 2, 1, 3]]]))
......
"""Class for subgraph data structure."""
from __future__ import absolute_import
from .frame import Frame, FrameRef
from .graph import DGLGraph
from . import utils
from .base import DGLError
from .graph_index import map_to_subgraph_nid
class DGLSubGraph(DGLGraph):
"""The subgraph class.
There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to parent graph. There are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only on structure; graph mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
sgi : SubgraphIndex
Internal subgraph data structure.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, sgi, shared=False):
super(DGLSubGraph, self).__init__(graph_data=sgi.graph,
readonly=True)
if shared:
raise DGLError('Shared mode is not yet supported.')
self._parent = parent
self._parent_nid = sgi.induced_nodes
self._parent_eid = sgi.induced_edges
self._subgraph_index = sgi
# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
@property
def parent_nid(self):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._parent_nid.tousertensor()
def _get_parent_eid(self):
# The parent eid might be lazily evaluated and thus may not
# be an index. Instead, it's a lambda function that returns
# an index.
if isinstance(self._parent_eid, utils.Index):
return self._parent_eid
else:
return self._parent_eid()
@property
def parent_eid(self):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._get_parent_eid().tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
Parameters
----------
inplace : bool
If true, use inplace write (no gradient but faster)
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
if self._parent._edge_frame.num_rows != 0:
self._parent._edge_frame.update_rows(
self._get_parent_eid(), self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
"""
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._get_parent_eid()]))
def map_to_subgraph_nid(self, parent_vids):
"""Map the node Ids in the parent graph to the node Ids in the subgraph.
Parameters
----------
parent_vids : list, tensor
The node ID array in the parent graph.
Returns
-------
tensor
The node ID array in the subgraph.
"""
v = map_to_subgraph_nid(self._subgraph_index, utils.toindex(parent_vids))
return v.tousertensor()
......@@ -7,12 +7,11 @@ from ._ffi.function import _init_api
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from . import ndarray as nd
from .subgraph import DGLSubGraph
from . import backend as F
from .graph_index import from_coo
from .graph_index import _get_halo_subgraph_inner_node
from .graph_index import _get_halo_subgraph_inner_edge
from .batched_graph import BatchedDGLGraph, unbatch
from .graph import unbatch
from .convert import graph, bipartite
from . import utils
from .base import EID, NID
......@@ -250,7 +249,6 @@ def reverse(g, share_ndata=False, share_edata=False):
Notes
-----
* This function does not support :class:`~dgl.BatchedDGLGraph` objects.
* We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently,
......@@ -307,12 +305,12 @@ def reverse(g, share_ndata=False, share_edata=False):
[2.],
[3.]])
"""
assert not isinstance(g, BatchedDGLGraph), \
'reverse is not supported for a BatchedDGLGraph object'
g_reversed = DGLGraph(multigraph=g.is_multigraph)
g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.all_edges(order='eid')
g_reversed.add_edges(g_edges[1], g_edges[0])
g_reversed._batch_num_nodes = g._batch_num_nodes
g_reversed._batch_num_edges = g._batch_num_edges
if share_ndata:
g_reversed._node_frame = g._node_frame
if share_edata:
......@@ -391,17 +389,14 @@ def laplacian_lambda_max(g):
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
The input graph, it should be an undirected graph.
Returns
-------
list :
* If the input g is a DGLGraph, the returned value would be
a list with one element, indicating the largest eigenvalue of g.
* If the input g is a BatchedDGLGraph, the returned value would
be a list, where the i-th item indicates the largest eigenvalue
of i-th graph in g.
Return a list, where the i-th item indicates the largest eigenvalue
of the i-th graph in g.
Examples
--------
......@@ -413,11 +408,7 @@ def laplacian_lambda_max(g):
>>> dgl.laplacian_lambda_max(g)
[1.809016994374948]
"""
if isinstance(g, BatchedDGLGraph):
g_arr = unbatch(g)
else:
g_arr = [g]
g_arr = unbatch(g)
rst = []
for g_i in g_arr:
n = g_i.number_of_nodes()
......@@ -573,7 +564,7 @@ def partition_graph_with_halo(g, node_part, num_hops):
for i, subg in enumerate(subgs):
inner_node = _get_halo_subgraph_inner_node(subg)
inner_edge = _get_halo_subgraph_inner_edge(subg)
subg = DGLSubGraph(g, subg)
subg = g._create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
subg.ndata['inner_node'] = inner_node
inner_edge = F.zerocopy_from_dlpack(inner_edge.to_dlpack())
......
......@@ -124,8 +124,8 @@ def test_node_subgraph():
subig = ig.node_subgraph(utils.toindex(randv))
check_basics(subg.graph, subig.graph)
check_graph_equal(subg.graph, subig.graph)
assert F.asnumpy(map_to_subgraph_nid(subg, utils.toindex(randv1[0:10])).tousertensor()
== map_to_subgraph_nid(subig, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
assert F.asnumpy(map_to_subgraph_nid(subg.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()
== map_to_subgraph_nid(subig.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
# node_subgraphs
randvs = []
......
......@@ -195,7 +195,7 @@ def test_softmax_edges():
def test_broadcast_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
feat0 = F.randn((1, 40))
ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
......@@ -204,23 +204,23 @@ def test_broadcast_nodes():
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
feat1 = F.randn((1, 40))
feat2 = F.randn((1, 40))
feat3 = F.randn((1, 40))
ground_truth = F.cat(
[feat0] * g0.number_of_nodes() +\
[feat1] * g1.number_of_nodes() +\
[feat2] * g2.number_of_nodes() +\
[feat3] * g3.number_of_nodes(), 0
)
assert F.allclose(dgl.broadcast_nodes(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
bg, F.cat([feat0, feat1, feat2, feat3], 0)
), ground_truth)
def test_broadcast_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
feat0 = F.randn((1, 40))
ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
......@@ -229,17 +229,17 @@ def test_broadcast_edges():
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
feat1 = F.randn((1, 40))
feat2 = F.randn((1, 40))
feat3 = F.randn((1, 40))
ground_truth = F.cat(
[feat0] * g0.number_of_edges() +\
[feat1] * g1.number_of_edges() +\
[feat2] * g2.number_of_edges() +\
[feat3] * g3.number_of_edges(), 0
)
assert F.allclose(dgl.broadcast_edges(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
bg, F.cat([feat0, feat1, feat2, feat3], 0)
), ground_truth)
if __name__ == '__main__':
......
......@@ -41,13 +41,13 @@ def test_basics():
eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
assert set(F.zerocopy_to_numpy(sg.parent_eid)) == eid
eid = F.tensor(sg.parent_eid)
# the subgraph is empty initially
assert len(sg.ndata) == 0
assert len(sg.edata) == 0
# the data is copied after explicit copy from
sg.copy_from_parent()
# the subgraph is empty initially except for NID/EID field
assert len(sg.ndata) == 1
assert len(sg.edata) == 1
# the data is copied after explicit copy from
sg.copy_from_parent()
assert len(sg.ndata) == 2
assert len(sg.edata) == 2
sh = sg.ndata['h']
assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)
'''
......
......@@ -328,7 +328,7 @@ def test_set2set():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph
bg = dgl.batch([g, g, g])
......@@ -346,7 +346,7 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
......@@ -366,13 +366,13 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(g, h0)
check_close(h1, F.sum(h0, 0))
check_close(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0)
check_close(h1, F.mean(h0, 0))
check_close(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0)
check_close(h1, F.max(h0, 0))
check_close(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
......
......@@ -124,7 +124,7 @@ def test_set2set():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(11))
......@@ -145,7 +145,7 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
......@@ -170,13 +170,13 @@ def test_simple_pool():
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
......@@ -228,7 +228,7 @@ def test_set_trans():
h1 = st_enc_1(g, h0)
assert h1.shape == h0.shape
h2 = st_dec(g, h1)
assert h2.shape[0] == 200 and h2.dim() == 1
assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
......
......@@ -93,13 +93,13 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
......@@ -246,7 +246,7 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
......
......@@ -91,12 +91,13 @@ def plot_tree(g):
plot_tree(graph.to_networkx())
#################################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`, or
# You can read more about the definition of :func:`~dgl.batch`, or
# skip ahead to the next step:
# .. note::
#
# **Definition**: A :class:`~dgl.batched_graph.BatchedDGLGraph` is a
# :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
# **Definition**: :func:`~dgl.batch` unions a list of :math:`B`
# :class:`~dgl.DGLGraph`\ s and returns a :class:`~dgl.DGLGraph` of batch
# size :math:`B`.
#
# - The union includes all the nodes,
# edges, and their features. The order of nodes, edges, and features is
......@@ -108,23 +109,16 @@ plot_tree(graph.to_networkx())
# :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
#
# - Therefore, performing feature transformation and message passing on
# ``BatchedDGLGraph`` is equivalent to doing those
# the batched graph is equivalent to doing those
# on all ``DGLGraph`` constituents in parallel.
#
# - Duplicate references to the same graph are
# treated as deep copies; the nodes, edges, and features are duplicated,
# and mutation on one reference does not affect the other.
# - Currently, ``BatchedDGLGraph`` is immutable in
# graph structure. You can't add
# nodes and edges to it. Mutable batched graphs may be supported in the
# (far) future.
# - The ``BatchedDGLGraph`` keeps track of the meta
# - The batched graph keeps track of the meta
# information of the constituents so it can be
# :func:`~dgl.batched_graph.unbatch`\ ed to a list of ``DGLGraph``\ s.
#
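# As a quick, hedged illustration (not part of the tutorial's own code; the
# ``batch_size`` attribute is assumed to be available on the unified
# ``DGLGraph``), batching and unbatching round-trip as follows::
#
#     bg = dgl.batch([g0, g1])          # one DGLGraph with batch size 2
#     bg.batch_size                     # 2
#     # node j of g1 sits at offset g0.number_of_nodes() + j in bg
#     g0_new, g1_new = dgl.unbatch(bg)  # split back into two DGLGraphs
#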
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
# Step 2: Tree-LSTM cell with message-passing APIs
# ------------------------------------------------
#
......
......@@ -798,9 +798,8 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__
# , and it is worth explaining one more time why this is so.
#
# By batching many small graphs, DGL internally maintains a large *container*
# graph (``BatchedDGLGraph``) over which ``update_all`` propels message-passing
# on all the edges and nodes.
# By batching many small graphs, DGL parallelizes message passing over the
# individual graphs in a batch.
#
# With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single giant
# graph consisting of :math:`N` isolated small graphs. For example, if we
......@@ -833,7 +832,7 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# internally groups nodes with the same in-degrees and calls the reduce UDF once
# for each group. Thus, batching also reduces the number of calls to these UDFs.
#
# The modification of the node/edge features of a ``BatchedDGLGraph`` object
# The modification of the node/edge features of the batched graph object
# does not affect the features of the original small graphs, so we
# need to replace the old graph list with the new graph list
# ``g_list = dgl.unbatch(bg)``.
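#
# A short sketch of that workflow (illustrative only; ``msg_func`` and
# ``reduce_func`` stand in for the tutorial's own UDFs)::
#
#     bg = dgl.batch(g_list)                  # N small graphs -> one graph
#     bg.update_all(msg_func, reduce_func)    # message passing on all N at once
#     g_list = dgl.unbatch(bg)                # recover the updated small graphs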
......