Unverified Commit 93ac29ce authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: default avatarMinjie Wang <minjie.wang@nyu.edu>
parent 8874e830
......@@ -8,6 +8,8 @@ import sys
import pickle
import time
from dgl.base import NID, EID
def SoftRelationPartition(edges, n, threshold=0.05):
"""This partitions a list of edges to n partitions according to their
relation types. For any relation with number of edges larger than the
......@@ -359,7 +361,7 @@ class TrainDataset(object):
return_false_neg=False)
class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
class ChunkNegEdgeSubgraph(dgl.DGLGraph):
"""Wrapper for negative graph
Parameters
......@@ -378,7 +380,11 @@ class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
"""
def __init__(self, subg, num_chunks, chunk_size,
neg_sample_size, neg_head):
super(ChunkNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi)
super(ChunkNegEdgeSubgraph, self).__init__(graph_data=subg.sgi.graph,
readonly=True,
parent=subg._parent)
self.ndata[NID] = subg.sgi.induced_nodes.tousertensor()
self.edata[EID] = subg.sgi.induced_edges.tousertensor()
self.subg = subg
self.num_chunks = num_chunks
self.chunk_size = chunk_size
......
"""MPNN"""
import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import Set2Set
from ..gnn import MPNNGNN
......@@ -77,6 +76,4 @@ class MPNNPredictor(nn.Module):
"""
node_feats = self.gnn(g, node_feats, edge_feats)
graph_feats = self.readout(g, node_feats)
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return self.predict(graph_feats)
......@@ -4,8 +4,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import BatchedDGLGraph
__all__ = ['AttentiveFPReadout']
class GlobalPool(nn.Module):
......@@ -59,10 +57,7 @@ class GlobalPool(nn.Module):
g.ndata['a'] = dgl.softmax_nodes(g, 'z')
g.ndata['hv'] = self.project_nodes(node_feats)
if isinstance(g, BatchedDGLGraph):
g_repr = dgl.sum_nodes(g, 'hv', 'a')
else:
g_repr = dgl.sum_nodes(g, 'hv', 'a').unsqueeze(0)
g_repr = dgl.sum_nodes(g, 'hv', 'a')
context = F.elu(g_repr)
if get_node_weight:
......@@ -121,9 +116,6 @@ class AttentiveFPReadout(nn.Module):
g.ndata['hv'] = node_feats
g_feats = dgl.sum_nodes(g, 'hv')
if not isinstance(g, BatchedDGLGraph):
g_feats = g_feats.unsqueeze(0)
if get_node_weight:
node_weights = []
......
......@@ -2,8 +2,6 @@
import dgl
import torch.nn as nn
from dgl import BatchedDGLGraph
__all__ = ['MLPNodeReadout']
class MLPNodeReadout(nn.Module):
......@@ -64,6 +62,4 @@ class MLPNodeReadout(nn.Module):
elif self.mode == 'sum':
graph_feats = dgl.sum_nodes(g, 'h')
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return graph_feats
......@@ -3,7 +3,6 @@ import dgl
import torch
import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax']
......@@ -45,10 +44,5 @@ class WeightedSumAndMax(nn.Module):
with bg.local_scope():
bg.ndata['h'] = feats
h_g_max = dgl.max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
return h_g
......@@ -93,8 +93,8 @@ def collate_molgraphs(data):
-------
smiles : list
List of smiles
bg : BatchedDGLGraph
Batched DGLGraphs
bg : DGLGraph
The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
......
......@@ -45,6 +45,24 @@ Querying graph structure
DGLGraph.out_degree
DGLGraph.out_degrees
Querying batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.batch_size
DGLGraph.batch_num_nodes
DGLGraph.batch_num_edges
Querying sub-graph/parent-graph belonging information
-----------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent
Removing nodes and edges
------------------------
......@@ -66,6 +84,8 @@ Transforming graph
DGLGraph.line_graph
DGLGraph.reverse
DGLGraph.readonly
DGLGraph.flatten
DGLGraph.detach_parent
Converting from/to other format
-------------------------------
......@@ -121,3 +141,29 @@ Computing with DGLGraph
DGLGraph.filter_nodes
DGLGraph.filter_edges
DGLGraph.to
Batch and Unbatch
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent_nid
DGLGraph.parent_eid
DGLGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.copy_from_parent
DGLGraph.copy_to_parent
......@@ -6,7 +6,7 @@ API Reference
graph
heterograph
batch
readout
batch_heterograph
nn
init
......@@ -17,7 +17,6 @@ API Reference
sampler
data
transform
subgraph
graph_store
nodeflow
random
......
.. _apibatch:
dgl.batched_graph
dgl.readout
==================================================
.. currentmodule:: dgl
.. autoclass:: BatchedDGLGraph
Merge and decompose
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Query batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
BatchedDGLGraph.batch_size
BatchedDGLGraph.batch_num_nodes
BatchedDGLGraph.batch_num_edges
Graph Readout
-------------
......
.. _apisubgraph:
dgl.subgraph
================================================
.. currentmodule:: dgl.subgraph
.. autoclass:: DGLSubGraph
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.parent_nid
DGLSubGraph.parent_eid
DGLSubGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.copy_from_parent
DGLSubGraph.copy_to_parent
......@@ -101,12 +101,11 @@ a useful manual for in-depth developers.
api/python/graph
api/python/heterograph
api/python/batch
api/python/readout
api/python/batch_heterograph
api/python/nn
api/python/function
api/python/udf
api/python/subgraph
api/python/traversal
api/python/propagate
api/python/transform
......
......@@ -67,7 +67,7 @@ def load_data(args):
test_dataset = PPIDataset('test')
PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask',
'val_mask', 'features', 'labels', 'num_labels', 'graph'])
G = dgl.BatchedDGLGraph(
G = dgl.batch(
[train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None)
G = G.to_networkx()
# hack to dodge the potential bugs of to_networkx
......
......@@ -226,8 +226,8 @@ def collate_molgraphs(data):
-------
smiles : list
List of smiles
bg : BatchedDGLGraph
Batched DGLGraphs
bg : DGLGraph
The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
......
......@@ -19,10 +19,10 @@ from ._ffi.function import register_func, get_global_func, list_global_func_name
from ._ffi.base import DGLError, __version__
from .base import ALL, NTYPE, NID, ETYPE, EID
from .batched_graph import *
from .readout import *
from .batched_heterograph import *
from .convert import *
from .graph import DGLGraph
from .graph import DGLGraph, batch, unbatch
from .generators import *
from .heterograph import DGLHeteroGraph
from .nodeflow import *
......
......@@ -12,7 +12,8 @@ from ..._ffi.ndarray import empty
from ... import utils
from ...nodeflow import NodeFlow
from ... import backend as F
from ... import subgraph
from ...graph import DGLGraph
from ...base import NID, EID
try:
import Queue as queue
......@@ -431,13 +432,17 @@ class LayerSampler(NodeFlowSampler):
nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows
class EdgeSubgraph(subgraph.DGLSubGraph):
class EdgeSubgraph(DGLGraph):
''' The subgraph sampled from an edge sampler.
A user can access the head nodes and tail nodes of the subgraph directly.
'''
def __init__(self, parent, sgi, neg):
super(EdgeSubgraph, self).__init__(parent, sgi)
super(EdgeSubgraph, self).__init__(graph_data=sgi.graph,
readonly=True,
parent=parent)
self.ndata[NID] = sgi.induced_nodes.tousertensor()
self.edata[EID] = sgi.induced_edges.tousertensor()
self.sgi = sgi
self.neg = neg
self.head = None
......@@ -735,7 +740,9 @@ class EdgeSampler(object):
if self._negative_mode == "":
# If no negative subgraphs.
return [subgraph.DGLSubGraph(self.g, subg) for subg in subgs]
return [self.g._create_subgraph(subg,
subg.induced_nodes,
subg.induced_edges) for subg in subgs]
else:
rets = []
assert len(subgs) % 2 == 0
......
"""For Graph Serialization"""
from __future__ import absolute_import
from ..graph import DGLGraph
from ..batched_graph import BatchedDGLGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
......@@ -29,7 +28,7 @@ class GraphData(ObjectBase):
@staticmethod
def create(g: DGLGraph):
"""Create GraphData"""
assert not isinstance(g, BatchedDGLGraph), "BatchedDGLGraph is not supported for serialization"
assert g.batch_size == 1, "BatchedDGLGraph is not supported for serialization"
ghandle = g._graph
if len(g.ndata) != 0:
node_tensors = dict()
......
"""Dataset for stochastic block model."""
import math
import os
import pickle
import random
import numpy as np
import numpy.random as npr
import scipy as sp
import networkx as nx
from ..batched_graph import batch
from ..graph import DGLGraph
from ..graph import DGLGraph, batch
from ..utils import Index
def sbm(n_blocks, block_size, p, q, rng=None):
......
......@@ -3,10 +3,11 @@ from __future__ import absolute_import
from collections import defaultdict
from contextlib import contextmanager
from typing import Iterable
import networkx as nx
import dgl
from .base import ALL, is_all, DGLError, dgl_warning
from .base import ALL, NID, EID, is_all, DGLError, dgl_warning
from . import backend as F
from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
......@@ -16,7 +17,7 @@ from . import utils
from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph']
__all__ = ['DGLGraph', 'batch', 'unbatch']
class DGLBaseGraph(object):
"""Base graph class.
......@@ -734,6 +735,24 @@ class DGLBaseGraph(object):
v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor()
def mutation(func):
    """Decorator for ``DGLGraph`` methods that mutate graph structure.

    Raises ``DGLError`` on readonly graphs. When the graph carries batching
    information (``batch_size > 1``) or a parent link (it is a subgraph),
    the mutation is still performed but a warning is emitted, because the
    mutation invalidates those properties.
    """
    def inner(g, *args, **kwargs):
        if g.is_readonly:
            raise DGLError("Readonly graph. Mutation is not allowed.")
        if g.batch_size > 1:
            dgl_warning("The graph has batch_size > 1, and mutation would break"
                        " batching related properties, call `flatten` to remove"
                        " batching information of the graph.")
        if g._parent is not None:
            dgl_warning("The graph is a subgraph of a parent graph, and mutation"
                        " would break subgraph related properties, call `detach"
                        "_parent` to remove its connection with its parent.")
        # Propagate the wrapped method's return value; the original wrapper
        # discarded it, which would silently break any mutator that returns one.
        return func(g, *args, **kwargs)
    return inner
class DGLGraph(DGLBaseGraph):
"""Base graph class.
......@@ -902,7 +921,10 @@ class DGLGraph(DGLBaseGraph):
edge_frame=None,
multigraph=None,
readonly=False,
sort_csr=False):
sort_csr=False,
batch_num_nodes=None,
batch_num_edges=None,
parent=None):
# graph
if isinstance(graph_data, DGLGraph):
gidx = graph_data._graph
......@@ -936,6 +958,22 @@ class DGLGraph(DGLBaseGraph):
self._apply_node_func = None
self._apply_edge_func = None
# batched graph
self._batch_num_nodes = batch_num_nodes
self._batch_num_edges = batch_num_edges
# set parent if the graph is a induced subgraph.
self._parent = parent
def _create_subgraph(self, sgi, induced_nodes, induced_edges):
    """Internal function to create a subgraph from index.

    Builds a read-only ``DGLGraph`` whose parent is ``self`` and records
    the induced node/edge ids under the reserved NID/EID feature names.
    """
    graph = DGLGraph(graph_data=sgi.graph, readonly=True, parent=self)
    graph.ndata[NID] = induced_nodes.tousertensor()
    graph.edata[EID] = induced_edges.tousertensor()
    return graph
def _get_msg_index(self):
if self._msg_index is None:
self._msg_index = utils.zero_index(size=self.number_of_edges())
......@@ -944,6 +982,7 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index):
self._msg_index = index
@mutation
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
......@@ -995,6 +1034,7 @@ class DGLGraph(DGLBaseGraph):
else:
self._node_frame.append(data)
@mutation
def add_edge(self, u, v, data=None):
"""Add one new edge between u and v.
......@@ -1053,6 +1093,7 @@ class DGLGraph(DGLBaseGraph):
self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1)
@mutation
def add_edges(self, u, v, data=None):
"""Add multiple edges for list of source nodes u and destination nodes
v. A single edge is added between every pair of ``u[i]`` and ``v[i]``.
......@@ -1114,6 +1155,7 @@ class DGLGraph(DGLBaseGraph):
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)
@mutation
def remove_nodes(self, vids):
"""Remove multiple nodes, edges that have connection with these nodes would also be removed.
......@@ -1163,8 +1205,6 @@ class DGLGraph(DGLBaseGraph):
add_edges
remove_edges
"""
if self.is_readonly:
raise DGLError("remove_nodes is not supported by read-only graph.")
induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids))
sgi = self._graph.node_subgraph(induced_nodes)
......@@ -1180,6 +1220,7 @@ class DGLGraph(DGLBaseGraph):
self._graph = sgi.graph
@mutation
def remove_edges(self, eids):
"""Remove multiple edges.
......@@ -1226,8 +1267,6 @@ class DGLGraph(DGLBaseGraph):
add_edges
remove_nodes
"""
if self.is_readonly:
raise DGLError("remove_edges is not supported by read-only graph.")
induced_edges = utils.set_diff(
utils.toindex(range(self.number_of_edges())), utils.toindex(eids))
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True)
......@@ -1244,6 +1283,112 @@ class DGLGraph(DGLBaseGraph):
self._graph = sgi.graph
@property
def parent_nid(self):
    """Get the parent node ids.

    The returned tensor can be used as a map from the node id
    in this subgraph to the node id in the parent graph.

    Returns
    -------
    Tensor
        The parent node id array.
    """
    if self._parent is not None:
        return self.ndata[NID]
    raise DGLError("We only support parent_nid for subgraphs.")
@property
def parent_eid(self):
    """Get the parent edge ids.

    The returned tensor can be used as a map from the edge id
    in this subgraph to the edge id in the parent graph.

    Returns
    -------
    Tensor
        The parent edge id array.
    """
    if self._parent is not None:
        return self.edata[EID]
    raise DGLError("We only support parent_eid for subgraphs.")
def copy_to_parent(self, inplace=False):
    """Write node/edge features to the parent graph.

    Parameters
    ----------
    inplace : bool
        If true, use inplace write (no gradient but faster)
    """
    if self._parent is None:
        raise DGLError("We only support copy_to_parent for subgraphs.")
    # Temporarily remove the induced-id columns so they are not written
    # back into the parent's frames as ordinary features.
    induced_nids = self.ndata.pop(NID)
    induced_eids = self.edata.pop(EID)
    parent = self._parent
    parent._node_frame.update_rows(
        utils.toindex(induced_nids), self._node_frame, inplace=inplace)
    # Skip edge updates when the parent has no edge rows at all.
    if parent._edge_frame.num_rows != 0:
        parent._edge_frame.update_rows(
            utils.toindex(induced_eids), self._edge_frame, inplace=inplace)
    # Restore the induced-id columns removed above.
    self.ndata[NID] = induced_nids
    self.edata[EID] = induced_eids
def copy_from_parent(self):
    """Copy node/edge features from the parent graph.

    All old features will be removed.

    Raises
    ------
    DGLError
        If the graph is not a subgraph (has no parent).
    """
    if self._parent is None:
        raise DGLError("We only support copy_from_parent for subgraphs.")
    nids = self.ndata[NID]
    eids = self.edata[EID]
    # Replace the local frames with rows sliced from the parent's frames
    # (only when the parent actually has rows and columns to copy).
    if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
        self._node_frame = FrameRef(Frame(
            self._parent._node_frame[utils.toindex(nids)]))
    if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
        self._edge_frame = FrameRef(Frame(
            self._parent._edge_frame[utils.toindex(eids)]))
    # Replacing the frames drops the induced-id columns, so restore them.
    # BUG FIX: the original wrote the edge ids back under the node-id key
    # (self.edata[NID] = eids), losing the EID column and polluting edata
    # with a stray NID entry; the edge ids belong under EID.
    self.ndata[NID] = nids
    self.edata[EID] = eids
def map_to_subgraph_nid(self, parent_vids):
    """Map the node Ids in the parent graph to the node Ids in the subgraph.

    Parameters
    ----------
    parent_vids : list, tensor
        The node ID array in the parent graph.

    Returns
    -------
    tensor
        The node ID array in the subgraph.
    """
    if self._parent is None:
        raise DGLError("We only support map_to_subgraph_nid for subgraphs.")
    induced = utils.toindex(self.ndata[NID])
    query = utils.toindex(parent_vids)
    return graph_index.map_to_subgraph_nid(induced, query).tousertensor()
def flatten(self):
    """Remove all batching information of the graph, and regard the current
    graph as an independent graph rather than a batched graph.

    Graph topology and attributes would not be affected.
    """
    # Clearing both counters makes batch_size report 1 again.
    self._batch_num_nodes = self._batch_num_edges = None
def detach_parent(self):
    """Detach the current graph from its parent, and regard the current graph
    as an independent graph rather than a subgraph.

    Graph topology and attributes would not be affected.
    """
    # Drop the parent link, then remove the induced-id columns that mapped
    # this graph's nodes/edges back to the parent.
    self._parent = None
    # NOTE(review): these pops raise KeyError when NID/EID are absent
    # (i.e. the graph was never created as a subgraph) — confirm callers
    # only invoke detach_parent on subgraphs.
    self.ndata.pop(NID)
    self.edata.pop(EID)
def clear(self):
"""Remove all nodes and edges, as well as their features, from the
graph.
......@@ -1710,6 +1855,53 @@ class DGLGraph(DGLBaseGraph):
"""
return self.edges[:].data
@property
def batch_size(self):
    """Number of graphs in this batch.

    Returns
    -------
    int
        Number of graphs in this batch."""
    nodes_per_graph = self.batch_num_nodes
    return 1 if nodes_per_graph is None else len(nodes_per_graph)
@property
def batch_num_nodes(self):
    """Number of nodes of each graph in this batch.

    Returns
    -------
    list
        Number of nodes of each graph in this batch."""
    # A graph without batching info is treated as a batch of one.
    return ([self.number_of_nodes()] if self._batch_num_nodes is None
            else self._batch_num_nodes)
@property
def batch_num_edges(self):
    """Number of edges of each graph in this batch.

    Returns
    -------
    list
        Number of edges of each graph in this batch."""
    # A graph without batching info is treated as a batch of one.
    return ([self.number_of_edges()] if self._batch_num_edges is None
            else self._batch_num_edges)
@property
def parent(self):
    """If the current graph is an induced subgraph of a parent graph, return
    its parent graph; otherwise return None.

    Returns
    -------
    DGLGraph or None
        The parent graph of the current graph.
    """
    return self._parent
def init_ndata(self, ndata_name, shape, dtype, ctx=F.cpu()):
"""Create node embedding.
......@@ -2914,7 +3106,7 @@ class DGLGraph(DGLBaseGraph):
Returns
-------
G : DGLSubGraph
G : DGLGraph
The subgraph.
The nodes are relabeled so that node `i` in the subgraph is mapped
to node `nodes[i]` in the original graph.
......@@ -2942,14 +3134,12 @@ class DGLGraph(DGLBaseGraph):
See Also
--------
DGLSubGraph
subgraphs
edge_subgraph
"""
from . import subgraph
induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
return subgraph.DGLSubGraph(self, sgi)
return self._create_subgraph(sgi, sgi.induced_nodes, sgi.induced_edges)
def subgraphs(self, nodes):
"""Return a list of subgraphs, each induced in the corresponding given
......@@ -2966,18 +3156,17 @@ class DGLGraph(DGLBaseGraph):
Returns
-------
G : A list of DGLSubGraph
G : A list of DGLGraph
The subgraphs.
See Also
--------
DGLSubGraph
subgraph
"""
from . import subgraph
induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes)
return [subgraph.DGLSubGraph(self, sgi) for sgi in sgis]
return [self._create_subgraph(
sgi, sgi.induced_nodes, sgi.induced_edges) for sgi in sgis]
def edge_subgraph(self, edges, preserve_nodes=False):
"""Return the subgraph induced on given edges.
......@@ -2994,7 +3183,7 @@ class DGLGraph(DGLBaseGraph):
Returns
-------
G : DGLSubGraph
G : DGLGraph
The subgraph.
The edges are relabeled so that edge `i` in the subgraph is mapped
to edge `edges[i]` in the original graph.
......@@ -3031,13 +3220,11 @@ class DGLGraph(DGLBaseGraph):
See Also
--------
DGLSubGraph
subgraph
"""
from . import subgraph
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
return subgraph.DGLSubGraph(self, sgi)
return self._create_subgraph(sgi, sgi.induced_nodes, sgi.induced_edges)
def adjacency_matrix_scipy(self, transpose=None, fmt='csr', return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph.
......@@ -3427,9 +3614,13 @@ class DGLGraph(DGLBaseGraph):
# otherwise the default initializer will be used.
sync_frame_initializer(local_node_frame._frame, self._node_frame._frame)
sync_frame_initializer(local_edge_frame._frame, self._edge_frame._frame)
return DGLGraph(self._graph,
local_node_frame,
local_edge_frame)
return DGLGraph(graph_data=self._graph,
node_frame=local_node_frame,
edge_frame=local_edge_frame,
readonly=self.is_readonly,
batch_num_nodes=self.batch_num_nodes,
batch_num_edges=self.batch_num_edges,
parent=self._parent)
@contextmanager
def local_scope(self):
......@@ -3489,6 +3680,178 @@ class DGLGraph(DGLBaseGraph):
self._node_frame = old_nframe
self._edge_frame = old_eframe
############################################################
# Batch/Unbatch APIs
############################################################
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
    """Batch a collection of :class:`~dgl.DGLGraph` and return a batched
    :class:`DGLGraph` object that is independent of the :attr:`graph_list`, the batch
    size of the returned graph is the length of :attr:`graph_list`.

    Parameters
    ----------
    graph_list : iterable
        A collection of :class:`~dgl.DGLGraph` to be batched.
    node_attrs : None, str or iterable
        The node attributes to be batched. If ``None``, the returned :class:`DGLGraph`
        object will not have any node attributes. By default, all node attributes will
        be batched. If ``str`` or iterable, this should specify exactly what node
        attributes to be batched.
    edge_attrs : None, str or iterable, optional
        Same as for the case of :attr:`node_attrs`

    Returns
    -------
    DGLGraph
        One single batched graph.

    See Also
    --------
    unbatch
    """
    # NOTE(review): a single-element list is returned as-is, so in this case
    # the result is NOT independent of graph_list as the docstring claims —
    # confirm callers do not mutate the result expecting a copy. Also note
    # len() requires graph_list to be a sized container, not any iterable.
    if len(graph_list) == 1:
        return graph_list[0]

    def _init_attrs(attrs, mode):
        """Collect attributes of given mode (node/edge) from graph_list.

        Parameters
        ----------
        attrs: None or ALL or str or iterator
            The attributes to collect. If ALL, check if all graphs have the same
            attribute set and return the attribute set. If None, return an empty
            list. If it is a string or a iterator of string, return these
            attributes.
        mode: str
            Suggest to collect node attributes or edge attributes.

        Returns
        -------
        Iterable
            The obtained attribute set.
        """
        if mode == 'node':
            nitems_list = [g.number_of_nodes() for g in graph_list]
            attrs_list = [set(g.node_attr_schemes().keys()) for g in graph_list]
        else:
            nitems_list = [g.number_of_edges() for g in graph_list]
            attrs_list = [set(g.edge_attr_schemes().keys()) for g in graph_list]
        if attrs is None:
            return []
        elif is_all(attrs):
            attrs = set()
            # Check if at least a graph has mode items and associated features.
            # The first such graph becomes the reference for the consistency
            # check below; empty graphs are skipped.
            for i, (g_num_items, g_attrs) in enumerate(zip(nitems_list, attrs_list)):
                if g_num_items > 0 and len(g_attrs) > 0:
                    attrs = g_attrs
                    ref_g_index = i
                    break
            # Check if all the graphs with mode items have the same associated features.
            if len(attrs) > 0:
                for i, (g_num_items, g_attrs) in enumerate(zip(nitems_list, attrs_list)):
                    if g_attrs != attrs and g_num_items > 0:
                        raise ValueError('Expect graph {0} and {1} to have the same {2} '
                                         'attributes when {2}_attrs=ALL, got {3} and {4}.'
                                         .format(ref_g_index, i, mode, attrs, g_attrs))
            return attrs
        elif isinstance(attrs, str):
            return [attrs]
        elif isinstance(attrs, Iterable):
            return attrs
        else:
            raise ValueError('Expected {} attrs to be of type None str or Iterable, '
                             'got type {}'.format(mode, type(attrs)))

    node_attrs = _init_attrs(node_attrs, 'node')
    edge_attrs = _init_attrs(edge_attrs, 'edge')

    # create batched graph index (disjoint union relabels nodes/edges so each
    # graph occupies a contiguous id range)
    batched_index = graph_index.disjoint_union([g._graph for g in graph_list])
    # create batched node and edge frames
    if len(node_attrs) == 0:
        batched_node_frame = FrameRef(Frame(num_rows=batched_index.number_of_nodes()))
    else:
        # NOTE: following code will materialize the columns of the input graphs.
        cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
                            if gr.number_of_nodes() > 0], dim=0)
                for key in node_attrs}
        batched_node_frame = FrameRef(Frame(cols))

    if len(edge_attrs) == 0:
        batched_edge_frame = FrameRef(Frame(num_rows=batched_index.number_of_edges()))
    else:
        cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
                            if gr.number_of_edges() > 0], dim=0)
                for key in edge_attrs}
        batched_edge_frame = FrameRef(Frame(cols))

    batch_size = 0
    batch_num_nodes = []
    batch_num_edges = []
    for grh in graph_list:
        # handle the input is again a batched graph: its per-graph counts are
        # concatenated so the result stays flat (no nested batching).
        batch_size += grh.batch_size
        batch_num_nodes += grh.batch_num_nodes
        batch_num_edges += grh.batch_num_edges

    return DGLGraph(graph_data=batched_index,
                    node_frame=batched_node_frame,
                    edge_frame=batched_edge_frame,
                    batch_num_nodes=batch_num_nodes,
                    batch_num_edges=batch_num_edges)
def unbatch(graph):
    """Return the list of graphs in this batch.

    Parameters
    ----------
    graph : DGLGraph
        The batched graph.

    Returns
    -------
    list
        A list of :class:`~dgl.DGLGraph` objects whose attributes are obtained
        by partitioning the attributes of the :attr:`graph`. The length of the
        list is the same as the batch size of :attr:`graph`.

    Notes
    -----
    Unbatching will break each field tensor of the batched graph into smaller
    partitions.

    For simpler tasks such as node/edge state aggregation, try to use
    readout functions.

    See Also
    --------
    batch
    """
    # NOTE(review): a graph with batch_size 1 is returned as-is (not a copy);
    # confirm callers do not mutate the result expecting independence.
    if graph.batch_size == 1:
        return [graph]
    bsize = graph.batch_size
    bnn = graph.batch_num_nodes
    bne = graph.batch_num_edges
    # Split the graph index into per-graph components along the stored
    # per-graph node counts.
    pttns = graph_index.disjoint_partition(graph._graph, utils.toindex(bnn))
    # split the frames: allocate one empty frame per component, then slice
    # every feature column into per-graph pieces.
    node_frames = [FrameRef(Frame(num_rows=n)) for n in bnn]
    edge_frames = [FrameRef(Frame(num_rows=n)) for n in bne]
    for attr, col in graph._node_frame.items():
        col_splits = F.split(col, bnn, dim=0)
        for i in range(bsize):
            node_frames[i][attr] = col_splits[i]
    for attr, col in graph._edge_frame.items():
        col_splits = F.split(col, bne, dim=0)
        for i in range(bsize):
            edge_frames[i][attr] = col_splits[i]
    return [DGLGraph(graph_data=pttns[i],
                     node_frame=node_frames[i],
                     edge_frame=edge_frames[i]) for i in range(bsize)]
############################################################
# Internal APIs
############################################################
......
......@@ -1182,13 +1182,13 @@ def from_edge_list(elist, is_multigraph, readonly):
raise DGLError('Invalid edge list. Nodes must start from 0.')
return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly)
def map_to_subgraph_nid(subgraph, parent_nids):
def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
Parameters
----------
subgraph: SubgraphIndex
the graph index of a subgraph
induced_nodes: utils.Index
Induced nodes of the subgraph.
parent_nids: utils.Index
Node Ids in the parent graph.
......@@ -1198,7 +1198,7 @@ def map_to_subgraph_nid(subgraph, parent_nids):
utils.Index
Node Ids in the subgraph.
"""
return utils.toindex(_CAPI_DGLMapSubgraphNID(subgraph.induced_nodes.todgltensor(),
return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(),
parent_nids.todgltensor()))
def transform_ids(mapping, ids):
......
......@@ -147,7 +147,7 @@ class GetContext(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......@@ -202,7 +202,7 @@ class GNNLayer(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......@@ -248,7 +248,7 @@ class GlobalPool(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......@@ -326,7 +326,7 @@ class AttentiveFP(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment