Unverified Commit 93ac29ce authored by Zihao Ye, committed by GitHub

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent 8874e830
......@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .gnn import GCNLayer, GATLayer
from ...batched_graph import BatchedDGLGraph, max_nodes
from ...readout import max_nodes
from ...nn.pytorch import WeightAndSum
from ...contrib.deprecation import deprecated
......@@ -74,13 +74,13 @@ class BaseGNNClassifier(nn.Module):
self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, bg, feats):
def forward(self, g, feats):
"""Multi-task prediction for a batch of molecules
Parameters
----------
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules
......@@ -91,18 +91,15 @@ class BaseGNNClassifier(nn.Module):
"""
# Update atom features with GNNs
for gnn in self.gnn_layers:
feats = gnn(bg, feats)
feats = gnn(g, feats)
# Compute molecule features from atom features
h_g_sum = self.weighted_sum_readout(bg, feats)
h_g_sum = self.weighted_sum_readout(g, feats)
with bg.local_scope():
bg.ndata['h'] = feats
h_g_max = max_nodes(bg, 'h')
with g.local_scope():
g.ndata['h'] = feats
h_g_max = max_nodes(g, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction
......
......@@ -42,13 +42,13 @@ class GCNLayer(nn.Module):
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, bg, feats):
def forward(self, g, feats):
"""Update atom representations
Parameters
----------
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
......@@ -58,7 +58,7 @@ class GCNLayer(nn.Module):
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization
"""
new_feats = self.graph_conv(bg, feats)
new_feats = self.graph_conv(g, feats)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
......
......@@ -7,7 +7,7 @@ import torch.nn.functional as F
import rdkit.Chem as Chem
from ....batched_graph import batch, unbatch
from ....graph import batch, unbatch
from ....contrib.deprecation import deprecated
from ....data.utils import get_download_dir
......
......@@ -63,7 +63,7 @@ class ChebConv(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
......@@ -3,8 +3,7 @@
from mxnet import gluon, nd
from mxnet.gluon import nn
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
......@@ -24,7 +23,7 @@ class SumPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
......@@ -33,9 +32,8 @@ class SumPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -61,7 +59,7 @@ class AvgPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
......@@ -70,9 +68,8 @@ class AvgPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -98,7 +95,7 @@ class MaxPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
......@@ -107,9 +104,8 @@ class MaxPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -139,7 +135,7 @@ class SortPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
......@@ -148,9 +144,8 @@ class SortPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`.
The output feature with shape :math:`(B, k * D)`, where
:math:`B` refers to the batch size.
"""
# Sort the feature of each node in ascending order.
with graph.local_scope():
......@@ -159,10 +154,7 @@ class SortPooling(nn.Block):
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k)[0].reshape(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(axis=0)
def __repr__(self):
return 'SortPooling(k={})'.format(self.k)
......@@ -195,7 +187,7 @@ class GlobalAttentionPooling(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
......@@ -204,9 +196,8 @@ class GlobalAttentionPooling(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
gate = self.gate_nn(feat)
......@@ -263,7 +254,7 @@ class Set2Set(nn.Block):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
......@@ -272,13 +263,10 @@ class Set2Set(nn.Block):
Returns
-------
mxnet.NDArray
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
......@@ -288,23 +276,14 @@ class Set2Set(nn.Block):
for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.ndim == 1: # graph is not a BatchedDGLGraph
readout = readout.expand_dims(0)
q_star = nd.concat(q, readout, dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(axis=0)
def __repr__(self):
summary = 'Set2Set('
......
......@@ -203,7 +203,7 @@ class AtomicConv(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
Topology based on which message passing is performed.
feat : Float32 tensor of shape (V, 1)
Initial node features, which are atomic numbers in the paper.
......
......@@ -68,7 +68,7 @@ class ChebConv(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
......
......@@ -4,9 +4,8 @@ import torch as th
import torch.nn as nn
import numpy as np
from ... import BatchedDGLGraph
from ...backend import pytorch as F
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
......@@ -29,7 +28,7 @@ class SumPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -38,9 +37,8 @@ class SumPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -62,7 +60,7 @@ class AvgPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -71,9 +69,8 @@ class AvgPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -95,7 +92,7 @@ class MaxPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -104,9 +101,8 @@ class MaxPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -132,7 +128,7 @@ class SortPooling(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -141,9 +137,8 @@ class SortPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`.
The output feature with shape :math:`(B, k * D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
# Sort the feature of each node in ascending order.
......@@ -152,10 +147,7 @@ class SortPooling(nn.Module):
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, idx=-1)[0].view(
-1, self.k * feat.shape[-1])
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return ret.squeeze(0)
class GlobalAttentionPooling(nn.Module):
......@@ -193,9 +185,8 @@ class GlobalAttentionPooling(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
gate = self.gate_nn(feat)
......@@ -257,7 +248,7 @@ class Set2Set(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -266,13 +257,10 @@ class Set2Set(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
batch_size = 1
if isinstance(graph, BatchedDGLGraph):
batch_size = graph.batch_size
h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
......@@ -283,23 +271,14 @@ class Set2Set(nn.Module):
for _ in range(self.n_iters):
q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.input_dim)
e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
if readout.dim() == 1: # graph is not a BatchedDGLGraph
readout = readout.unsqueeze(0)
q_star = th.cat([q, readout], dim=-1)
if isinstance(graph, BatchedDGLGraph):
return q_star
else:
return q_star.squeeze(0)
def extra_repr(self):
"""Set the extra representation of the module.
......@@ -574,7 +553,7 @@ class SetTransformerEncoder(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -585,14 +564,9 @@ class SetTransformerEncoder(nn.Module):
torch.Tensor
The output feature with shape :math:`(N, D)`.
"""
if isinstance(graph, BatchedDGLGraph):
lengths = graph.batch_num_nodes
else:
lengths = [graph.number_of_nodes()]
for layer in self.layers:
feat = layer(feat, lengths)
return feat
......@@ -640,7 +614,7 @@ class SetTransformerDecoder(nn.Module):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -649,25 +623,15 @@ class SetTransformerDecoder(nn.Module):
Returns
-------
torch.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
if isinstance(graph, BatchedDGLGraph):
len_pma = graph.batch_num_nodes
len_sab = [self.k] * graph.batch_size
else:
len_pma = [graph.number_of_nodes()]
len_sab = [self.k]
feat = self.pma(feat, len_pma)
for layer in self.layers:
feat = layer(feat, len_sab)
if isinstance(graph, BatchedDGLGraph):
return feat.view(graph.batch_size, self.k * self.d_model)
else:
return feat.view(self.k * self.d_model)
class WeightAndSum(nn.Module):
......@@ -686,13 +650,13 @@ class WeightAndSum(nn.Module):
nn.Sigmoid()
)
def forward(self, bg, feats):
def forward(self, g, feats):
"""Compute molecule representations out of atom representations
Parameters
----------
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules
* N is the total number of atoms in all molecules
......@@ -702,9 +666,9 @@ class WeightAndSum(nn.Module):
FloatTensor of shape (B, self.in_feats)
Representations for B molecules
"""
with bg.local_scope():
bg.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w')
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')
return h_g_sum
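
The pooling and readout docstrings above now describe a single convention: the output always carries a leading batch dimension :math:`B`, and a plain ``DGLGraph`` is treated as a batch of size 1. A minimal sketch of that behavior (assuming the ``dgl.nn.pytorch.glob.SumPooling`` module and a DGL build from around this change):

    import dgl
    import networkx as nx
    import torch
    from dgl.nn.pytorch.glob import SumPooling

    pool = SumPooling()

    # A single graph behaves like a batch of size 1.
    g = dgl.DGLGraph(nx.path_graph(4))
    feat = torch.randn(g.number_of_nodes(), 8)
    print(pool(g, feat).shape)    # expected: torch.Size([1, 8])

    # A batched graph yields one row per constituent graph.
    bg = dgl.batch([g, g, g])
    bfeat = torch.randn(bg.number_of_nodes(), 8)
    print(pool(bg, bfeat).shape)  # expected: torch.Size([3, 8])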
......@@ -4,8 +4,7 @@ import tensorflow as tf
from tensorflow.keras import layers
from ... import BatchedDGLGraph
from ...batched_graph import sum_nodes, mean_nodes, max_nodes, \
from ...readout import sum_nodes, mean_nodes, max_nodes, \
softmax_nodes, topk_nodes
......@@ -29,7 +28,7 @@ class SumPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -38,9 +37,8 @@ class SumPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -63,7 +61,7 @@ class AvgPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -72,9 +70,8 @@ class AvgPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -97,7 +94,7 @@ class MaxPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, *)` where
......@@ -106,9 +103,8 @@ class MaxPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(*)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, *)`.
The output feature with shape :math:`(B, *)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
......@@ -135,7 +131,7 @@ class SortPooling(layers.Layer):
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature with shape :math:`(N, D)` where
......@@ -144,9 +140,8 @@ class SortPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(k * D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, k * D)`.
The output feature with shape :math:`(B, k * D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
# Sort the feature of each node in ascending order.
......@@ -155,10 +150,7 @@ class SortPooling(layers.Layer):
# Sort nodes according to their last features.
ret = tf.reshape(topk_nodes(graph, 'h', self.k, idx=-1)[0], (
-1, self.k * feat.shape[-1]))
if isinstance(graph, BatchedDGLGraph):
return ret
else:
return tf.squeeze(ret, 0)
class GlobalAttentionPooling(layers.Layer):
......@@ -197,9 +189,8 @@ class GlobalAttentionPooling(layers.Layer):
Returns
-------
tf.Tensor
The output feature with shape :math:`(D)` (if
input graph is a BatchedDGLGraph, the result shape
would be :math:`(B, D)`.
The output feature with shape :math:`(B, D)`, where
:math:`B` refers to the batch size.
"""
with graph.local_scope():
gate = self.gate_nn(feat)
......@@ -234,13 +225,13 @@ class WeightAndSum(layers.Layer):
layers.Activation(tf.nn.sigmoid)
)
def call(self, bg, feats):
def call(self, g, feats):
"""Compute molecule representations out of atom representations
Parameters
----------
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules
* N is the total number of atoms in all molecules
......@@ -250,9 +241,9 @@ class WeightAndSum(layers.Layer):
FloatTensor of shape (B, self.in_feats)
Representations for B molecules
"""
with bg.local_scope():
bg.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w')
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')
return h_g_sum
"""Class for subgraph data structure."""
from __future__ import absolute_import
from .frame import Frame, FrameRef
from .graph import DGLGraph
from . import utils
from .base import DGLError
from .graph_index import map_to_subgraph_nid
class DGLSubGraph(DGLGraph):
"""The subgraph class.
There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to the parent graph, there are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only on structure; graph mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
sgi : SubgraphIndex
Internal subgraph data structure.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, sgi, shared=False):
super(DGLSubGraph, self).__init__(graph_data=sgi.graph,
readonly=True)
if shared:
raise DGLError('Shared mode is not yet supported.')
self._parent = parent
self._parent_nid = sgi.induced_nodes
self._parent_eid = sgi.induced_edges
self._subgraph_index = sgi
# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')
@property
def parent_nid(self):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._parent_nid.tousertensor()
def _get_parent_eid(self):
# The parent eid might be lazily evaluated and thus may not
# be an index. Instead, it's a lambda function that returns
# an index.
if isinstance(self._parent_eid, utils.Index):
return self._parent_eid
else:
return self._parent_eid()
@property
def parent_eid(self):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._get_parent_eid().tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
Parameters
----------
inplace : bool
If true, use inplace write (no gradient but faster)
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
if self._parent._edge_frame.num_rows != 0:
self._parent._edge_frame.update_rows(
self._get_parent_eid(), self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
"""
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._get_parent_eid()]))
def map_to_subgraph_nid(self, parent_vids):
"""Map the node Ids in the parent graph to the node Ids in the subgraph.
Parameters
----------
parent_vids : list, tensor
The node ID array in the parent graph.
Returns
-------
tensor
The node ID array in the subgraph.
"""
v = map_to_subgraph_nid(self._subgraph_index, utils.toindex(parent_vids))
return v.tousertensor()
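
The ``DGLSubGraph`` class removed here documented an explicit ``copy_from_parent``/``copy_to_parent`` workflow for moving features between a subgraph and its parent. A minimal sketch of that pre-refactor workflow as the docstring describes it, assuming the old ``DGLGraph.subgraph`` entry point that produced this class:

    import dgl
    import torch

    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 4])
    g.ndata['h'] = torch.arange(5, dtype=torch.float32).view(5, 1)

    sg = g.subgraph([1, 2, 3])   # induced subgraph on nodes 1..3
    # In the non-shared mode, features are not visible until copied from the parent.
    sg.copy_from_parent()
    sg.ndata['h'] = sg.ndata['h'] * 2
    # Write the updated features back to the corresponding parent rows.
    sg.copy_to_parent()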
......@@ -7,12 +7,11 @@ from ._ffi.function import _init_api
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from . import ndarray as nd
from .subgraph import DGLSubGraph
from . import backend as F
from .graph_index import from_coo
from .graph_index import _get_halo_subgraph_inner_node
from .graph_index import _get_halo_subgraph_inner_edge
from .batched_graph import BatchedDGLGraph, unbatch
from .graph import unbatch
from .convert import graph, bipartite
from . import utils
from .base import EID, NID
......@@ -250,7 +249,6 @@ def reverse(g, share_ndata=False, share_edata=False):
Notes
-----
* This function does not support :class:`~dgl.BatchedDGLGraph` objects.
* We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently,
......@@ -307,12 +305,12 @@ def reverse(g, share_ndata=False, share_edata=False):
[2.],
[3.]])
"""
assert not isinstance(g, BatchedDGLGraph), \
'reverse is not supported for a BatchedDGLGraph object'
g_reversed = DGLGraph(multigraph=g.is_multigraph)
g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.all_edges(order='eid')
g_reversed.add_edges(g_edges[1], g_edges[0])
g_reversed._batch_num_nodes = g._batch_num_nodes
g_reversed._batch_num_edges = g._batch_num_edges
if share_ndata:
g_reversed._node_frame = g._node_frame
if share_edata:
......@@ -391,16 +389,13 @@ def laplacian_lambda_max(g):
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
The input graph, it should be an undirected graph.
Returns
-------
list :
* If the input g is a DGLGraph, the returned value would be
a list with one element, indicating the largest eigenvalue of g.
* If the input g is a BatchedDGLGraph, the returned value would
be a list, where the i-th item indicates the largest eigenvalue
Return a list, where the i-th item indicates the largest eigenvalue
of the i-th graph in g.
Examples
......@@ -413,11 +408,7 @@ def laplacian_lambda_max(g):
>>> dgl.laplacian_lambda_max(g)
[1.809016994374948]
"""
if isinstance(g, BatchedDGLGraph):
g_arr = unbatch(g)
else:
g_arr = [g]
rst = []
for g_i in g_arr:
n = g_i.number_of_nodes()
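
The updated docstring states that ``laplacian_lambda_max`` always returns one value per graph, whether ``g`` is a single graph or was produced by ``dgl.batch``. A small sketch of that usage (assuming undirected inputs, as the docstring requires):

    import dgl
    import networkx as nx

    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.cycle_graph(4))
    bg = dgl.batch([g1, g2])

    # One largest Laplacian eigenvalue per constituent graph: a list of length 2.
    print(dgl.laplacian_lambda_max(bg))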
......@@ -573,7 +564,7 @@ def partition_graph_with_halo(g, node_part, num_hops):
for i, subg in enumerate(subgs):
inner_node = _get_halo_subgraph_inner_node(subg)
inner_edge = _get_halo_subgraph_inner_edge(subg)
subg = DGLSubGraph(g, subg)
subg = g._create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
subg.ndata['inner_node'] = inner_node
inner_edge = F.zerocopy_from_dlpack(inner_edge.to_dlpack())
......
......@@ -124,8 +124,8 @@ def test_node_subgraph():
subig = ig.node_subgraph(utils.toindex(randv))
check_basics(subg.graph, subig.graph)
check_graph_equal(subg.graph, subig.graph)
assert F.asnumpy(map_to_subgraph_nid(subg, utils.toindex(randv1[0:10])).tousertensor()
== map_to_subgraph_nid(subig, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
assert F.asnumpy(map_to_subgraph_nid(subg.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()
== map_to_subgraph_nid(subig.induced_nodes, utils.toindex(randv1[0:10])).tousertensor()).sum(0).item() == 10
# node_subgraphs
randvs = []
......
......@@ -195,7 +195,7 @@ def test_softmax_edges():
def test_broadcast_nodes():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
feat0 = F.randn((1, 40))
ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
......@@ -204,23 +204,23 @@ def test_broadcast_nodes():
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
feat1 = F.randn((1, 40))
feat2 = F.randn((1, 40))
feat3 = F.randn((1, 40))
ground_truth = F.cat(
[feat0] * g0.number_of_nodes() +\
[feat1] * g1.number_of_nodes() +\
[feat2] * g2.number_of_nodes() +\
[feat3] * g3.number_of_nodes(), 0
)
assert F.allclose(dgl.broadcast_nodes(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
bg, F.cat([feat0, feat1, feat2, feat3], 0)
), ground_truth)
def test_broadcast_edges():
# test#1: basic
g0 = dgl.DGLGraph(nx.path_graph(10))
feat0 = F.randn((40,))
feat0 = F.randn((1, 40))
ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)
......@@ -229,17 +229,17 @@ def test_broadcast_edges():
g2 = dgl.DGLGraph()
g3 = dgl.DGLGraph(nx.path_graph(12))
bg = dgl.batch([g0, g1, g2, g3])
feat1 = F.randn((40,))
feat2 = F.randn((40,))
feat3 = F.randn((40,))
ground_truth = F.stack(
feat1 = F.randn((1, 40))
feat2 = F.randn((1, 40))
feat3 = F.randn((1, 40))
ground_truth = F.cat(
[feat0] * g0.number_of_edges() +\
[feat1] * g1.number_of_edges() +\
[feat2] * g2.number_of_edges() +\
[feat3] * g3.number_of_edges(), 0
)
assert F.allclose(dgl.broadcast_edges(
bg, F.stack([feat0, feat1, feat2, feat3], 0)
bg, F.cat([feat0, feat1, feat2, feat3], 0)
), ground_truth)
if __name__ == '__main__':
......
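
The rewritten broadcast tests feed one feature row per graph, shape ``(B, 40)``, instead of a single flat vector. A minimal sketch of the corresponding call (assuming ``dgl.broadcast_nodes`` as used in the tests above):

    import dgl
    import networkx as nx
    import torch

    g0 = dgl.DGLGraph(nx.path_graph(3))
    g1 = dgl.DGLGraph(nx.path_graph(4))
    bg = dgl.batch([g0, g1])

    # One row per graph in the batch: shape (B, D) = (2, 5).
    graph_feat = torch.randn(2, 5)

    # Every node receives its own graph's row: result shape (N, D) = (7, 5).
    node_feat = dgl.broadcast_nodes(bg, graph_feat)
    print(node_feat.shape)   # expected: torch.Size([7, 5])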
......@@ -41,13 +41,13 @@ def test_basics():
eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
assert set(F.zerocopy_to_numpy(sg.parent_eid)) == eid
eid = F.tensor(sg.parent_eid)
# the subgraph is empty initially
assert len(sg.ndata) == 0
assert len(sg.edata) == 0
# the data is copied after explict copy from
sg.copy_from_parent()
# the subgraph is empty initially except for NID/EID field
assert len(sg.ndata) == 1
assert len(sg.edata) == 1
# the data is copied after an explicit copy_from_parent
sg.copy_from_parent()
assert len(sg.ndata) == 2
assert len(sg.edata) == 2
sh = sg.ndata['h']
assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)
'''
......
......@@ -328,7 +328,7 @@ def test_set2set():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph
bg = dgl.batch([g, g, g])
......@@ -346,7 +346,7 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
......@@ -366,13 +366,13 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(g, h0)
check_close(h1, F.sum(h0, 0))
check_close(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0)
check_close(h1, F.mean(h0, 0))
check_close(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0)
check_close(h1, F.max(h0, 0))
check_close(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
......
......@@ -124,7 +124,7 @@ def test_set2set():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(11))
......@@ -145,7 +145,7 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
......@@ -170,13 +170,13 @@ def test_simple_pool():
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
......@@ -228,7 +228,7 @@ def test_set_trans():
h1 = st_enc_1(g, h0)
assert h1.shape == h0.shape
h2 = st_dec(g, h1)
assert h2.shape[0] == 200 and h2.dim() == 1
assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(5))
......
......@@ -93,13 +93,13 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0))
assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
......@@ -246,7 +246,7 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
......
......@@ -91,12 +91,13 @@ def plot_tree(g):
plot_tree(graph.to_networkx())
#################################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`, or
# You can read more about the definition of :func:`~dgl.batch`, or
# skip ahead to the next step:
# .. note::
#
# **Definition**: A :class:`~dgl.batched_graph.BatchedDGLGraph` is a
# :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
# **Definition**: :func:`~dgl.batch` unions a list of :math:`B`
# :class:`~dgl.DGLGraph`\ s and returns a :class:`~dgl.DGLGraph` of batch
# size :math:`B`.
#
# - The union includes all the nodes,
# edges, and their features. The order of nodes, edges, and features are
......@@ -108,23 +109,16 @@ plot_tree(graph.to_networkx())
# :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
#
# - Therefore, performing feature transformation and message passing on
# ``BatchedDGLGraph`` is equivalent to doing those
# the batched graph is equivalent to doing those
# on all ``DGLGraph`` constituents in parallel.
#
# - Duplicate references to the same graph are
# treated as deep copies; the nodes, edges, and features are duplicated,
# and mutation on one reference does not affect the other.
# - Currently, ``BatchedDGLGraph`` is immutable in
# graph structure. You can't add
# nodes and edges to it. You need to support mutable batched graphs in
# (far) future.
# - The ``BatchedDGLGraph`` keeps track of the meta
# - The batched graph keeps track of the meta
# information of the constituents so it can be
# :func:`~dgl.unbatch`\ ed to a list of ``DGLGraph``\ s.
#
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
# Step 2: Tree-LSTM cell with message-passing APIs
# ------------------------------------------------
#
......
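
The note above defines ``dgl.batch`` as the union of :math:`B` graphs into a single ``DGLGraph`` of batch size :math:`B`, with node ids shifted by the sizes of the preceding graphs. A minimal sketch of that bookkeeping (assuming the APIs shown elsewhere in this diff):

    import dgl
    import networkx as nx

    g1 = dgl.DGLGraph(nx.path_graph(3))   # nodes 0..2
    g2 = dgl.DGLGraph(nx.path_graph(4))   # nodes 0..3

    bg = dgl.batch([g1, g2])
    print(bg.batch_size)           # 2
    print(bg.number_of_nodes())    # 3 + 4 = 7
    # Node j of the i-th graph becomes node j + sum of the earlier graphs' sizes;
    # e.g. node 1 of g2 maps to node 1 + 3 = 4 in the batched graph.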
......@@ -798,9 +798,8 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__
# , and it is worth explaining one more time why this is so.
#
# By batching many small graphs, DGL internally maintains a large *container*
# graph (``BatchedDGLGraph``) over which ``update_all`` propels message-passing
# on all the edges and nodes.
# By batching many small graphs, DGL parallelizes message passing over the individual
# graphs in the batch.
#
# With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single giant
# graph consisting of :math:`N` isolated small graphs. For example, if we
......@@ -833,7 +832,7 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# internally groups nodes with the same in-degrees and calls reduce UDF once
# for each group. Thus, batching also reduces number of calls to these UDFs.
#
# The modification of the node/edge features of a ``BatchedDGLGraph`` object
# The modification of the node/edge features of the batched graph object
# does not take effect on the features of the original small graphs, so we
# need to replace the old graph list with the new graph list
# ``g_list = dgl.unbatch(bg)``.
......
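
As the tutorial text notes, feature updates on the batched graph are not reflected in the original small graphs, so the old list has to be replaced via ``dgl.unbatch``. A minimal sketch of that pattern:

    import dgl
    import networkx as nx
    import torch

    g_list = [dgl.DGLGraph(nx.path_graph(3)) for _ in range(2)]
    for g in g_list:
        g.ndata['h'] = torch.zeros(g.number_of_nodes(), 1)

    bg = dgl.batch(g_list)
    bg.ndata['h'] = bg.ndata['h'] + 1.0   # modify features on the batched graph

    print(g_list[0].ndata['h'].sum())     # still 0: originals are untouched
    g_list = dgl.unbatch(bg)              # replace the old list to pick up the updates
    print(g_list[0].ndata['h'].sum())     # now 3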