Unverified Commit 93ac29ce authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: default avatarMinjie Wang <minjie.wang@nyu.edu>
parent 8874e830
...@@ -8,6 +8,8 @@ import sys ...@@ -8,6 +8,8 @@ import sys
import pickle import pickle
import time import time
from dgl.base import NID, EID
def SoftRelationPartition(edges, n, threshold=0.05): def SoftRelationPartition(edges, n, threshold=0.05):
"""This partitions a list of edges to n partitions according to their """This partitions a list of edges to n partitions according to their
relation types. For any relation with number of edges larger than the relation types. For any relation with number of edges larger than the
...@@ -359,7 +361,7 @@ class TrainDataset(object): ...@@ -359,7 +361,7 @@ class TrainDataset(object):
return_false_neg=False) return_false_neg=False)
class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph): class ChunkNegEdgeSubgraph(dgl.DGLGraph):
"""Wrapper for negative graph """Wrapper for negative graph
Parameters Parameters
...@@ -378,7 +380,11 @@ class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph): ...@@ -378,7 +380,11 @@ class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
""" """
def __init__(self, subg, num_chunks, chunk_size, def __init__(self, subg, num_chunks, chunk_size,
neg_sample_size, neg_head): neg_sample_size, neg_head):
super(ChunkNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi) super(ChunkNegEdgeSubgraph, self).__init__(graph_data=subg.sgi.graph,
readonly=True,
parent=subg._parent)
self.ndata[NID] = subg.sgi.induced_nodes.tousertensor()
self.edata[EID] = subg.sgi.induced_edges.tousertensor()
self.subg = subg self.subg = subg
self.num_chunks = num_chunks self.num_chunks = num_chunks
self.chunk_size = chunk_size self.chunk_size = chunk_size
......
"""MPNN""" """MPNN"""
import torch.nn as nn import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import Set2Set from dgl.nn.pytorch import Set2Set
from ..gnn import MPNNGNN from ..gnn import MPNNGNN
...@@ -77,6 +76,4 @@ class MPNNPredictor(nn.Module): ...@@ -77,6 +76,4 @@ class MPNNPredictor(nn.Module):
""" """
node_feats = self.gnn(g, node_feats, edge_feats) node_feats = self.gnn(g, node_feats, edge_feats)
graph_feats = self.readout(g, node_feats) graph_feats = self.readout(g, node_feats)
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return self.predict(graph_feats) return self.predict(graph_feats)
...@@ -4,8 +4,6 @@ import torch ...@@ -4,8 +4,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import BatchedDGLGraph
__all__ = ['AttentiveFPReadout'] __all__ = ['AttentiveFPReadout']
class GlobalPool(nn.Module): class GlobalPool(nn.Module):
...@@ -59,10 +57,7 @@ class GlobalPool(nn.Module): ...@@ -59,10 +57,7 @@ class GlobalPool(nn.Module):
g.ndata['a'] = dgl.softmax_nodes(g, 'z') g.ndata['a'] = dgl.softmax_nodes(g, 'z')
g.ndata['hv'] = self.project_nodes(node_feats) g.ndata['hv'] = self.project_nodes(node_feats)
if isinstance(g, BatchedDGLGraph): g_repr = dgl.sum_nodes(g, 'hv', 'a')
g_repr = dgl.sum_nodes(g, 'hv', 'a')
else:
g_repr = dgl.sum_nodes(g, 'hv', 'a').unsqueeze(0)
context = F.elu(g_repr) context = F.elu(g_repr)
if get_node_weight: if get_node_weight:
...@@ -121,9 +116,6 @@ class AttentiveFPReadout(nn.Module): ...@@ -121,9 +116,6 @@ class AttentiveFPReadout(nn.Module):
g.ndata['hv'] = node_feats g.ndata['hv'] = node_feats
g_feats = dgl.sum_nodes(g, 'hv') g_feats = dgl.sum_nodes(g, 'hv')
if not isinstance(g, BatchedDGLGraph):
g_feats = g_feats.unsqueeze(0)
if get_node_weight: if get_node_weight:
node_weights = [] node_weights = []
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
import dgl import dgl
import torch.nn as nn import torch.nn as nn
from dgl import BatchedDGLGraph
__all__ = ['MLPNodeReadout'] __all__ = ['MLPNodeReadout']
class MLPNodeReadout(nn.Module): class MLPNodeReadout(nn.Module):
...@@ -64,6 +62,4 @@ class MLPNodeReadout(nn.Module): ...@@ -64,6 +62,4 @@ class MLPNodeReadout(nn.Module):
elif self.mode == 'sum': elif self.mode == 'sum':
graph_feats = dgl.sum_nodes(g, 'h') graph_feats = dgl.sum_nodes(g, 'h')
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return graph_feats return graph_feats
...@@ -3,7 +3,6 @@ import dgl ...@@ -3,7 +3,6 @@ import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import WeightAndSum from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax'] __all__ = ['WeightedSumAndMax']
...@@ -45,10 +44,5 @@ class WeightedSumAndMax(nn.Module): ...@@ -45,10 +44,5 @@ class WeightedSumAndMax(nn.Module):
with bg.local_scope(): with bg.local_scope():
bg.ndata['h'] = feats bg.ndata['h'] = feats
h_g_max = dgl.max_nodes(bg, 'h') h_g_max = dgl.max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1) h_g = torch.cat([h_g_sum, h_g_max], dim=1)
return h_g return h_g
...@@ -93,8 +93,8 @@ def collate_molgraphs(data): ...@@ -93,8 +93,8 @@ def collate_molgraphs(data):
------- -------
smiles : list smiles : list
List of smiles List of smiles
bg : BatchedDGLGraph bg : DGLGraph
Batched DGLGraphs The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T) labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and Batched datapoint labels. B is len(data) and
T is the number of total tasks. T is the number of total tasks.
......
...@@ -45,6 +45,24 @@ Querying graph structure ...@@ -45,6 +45,24 @@ Querying graph structure
DGLGraph.out_degree DGLGraph.out_degree
DGLGraph.out_degrees DGLGraph.out_degrees
Querying batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.batch_size
DGLGraph.batch_num_nodes
DGLGraph.batch_num_edges
Querying sub-graph/parent-graph belonging information
-----------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent
Removing nodes and edges Removing nodes and edges
------------------------ ------------------------
...@@ -66,6 +84,8 @@ Transforming graph ...@@ -66,6 +84,8 @@ Transforming graph
DGLGraph.line_graph DGLGraph.line_graph
DGLGraph.reverse DGLGraph.reverse
DGLGraph.readonly DGLGraph.readonly
DGLGraph.flatten
DGLGraph.detach_parent
Converting from/to other format Converting from/to other format
------------------------------- -------------------------------
...@@ -121,3 +141,29 @@ Computing with DGLGraph ...@@ -121,3 +141,29 @@ Computing with DGLGraph
DGLGraph.filter_nodes DGLGraph.filter_nodes
DGLGraph.filter_edges DGLGraph.filter_edges
DGLGraph.to DGLGraph.to
Batch and Unbatch
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent_nid
DGLGraph.parent_eid
DGLGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.copy_from_parent
DGLGraph.copy_to_parent
...@@ -6,7 +6,7 @@ API Reference ...@@ -6,7 +6,7 @@ API Reference
graph graph
heterograph heterograph
batch readout
batch_heterograph batch_heterograph
nn nn
init init
...@@ -17,7 +17,6 @@ API Reference ...@@ -17,7 +17,6 @@ API Reference
sampler sampler
data data
transform transform
subgraph
graph_store graph_store
nodeflow nodeflow
random random
......
.. _apibatch: .. _apibatch:
dgl.batched_graph dgl.readout
================================================== ==================================================
.. currentmodule:: dgl .. currentmodule:: dgl
.. autoclass:: BatchedDGLGraph
Merge and decompose
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Query batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
BatchedDGLGraph.batch_size
BatchedDGLGraph.batch_num_nodes
BatchedDGLGraph.batch_num_edges
Graph Readout Graph Readout
------------- -------------
......
.. _apisubgraph:
dgl.subgraph
================================================
.. currentmodule:: dgl.subgraph
.. autoclass:: DGLSubGraph
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.parent_nid
DGLSubGraph.parent_eid
DGLSubGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.copy_from_parent
DGLSubGraph.copy_to_parent
...@@ -101,12 +101,11 @@ a useful manual for in-depth developers. ...@@ -101,12 +101,11 @@ a useful manual for in-depth developers.
api/python/graph api/python/graph
api/python/heterograph api/python/heterograph
api/python/batch api/python/readout
api/python/batch_heterograph api/python/batch_heterograph
api/python/nn api/python/nn
api/python/function api/python/function
api/python/udf api/python/udf
api/python/subgraph
api/python/traversal api/python/traversal
api/python/propagate api/python/propagate
api/python/transform api/python/transform
......
...@@ -67,7 +67,7 @@ def load_data(args): ...@@ -67,7 +67,7 @@ def load_data(args):
test_dataset = PPIDataset('test') test_dataset = PPIDataset('test')
PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask', PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask',
'val_mask', 'features', 'labels', 'num_labels', 'graph']) 'val_mask', 'features', 'labels', 'num_labels', 'graph'])
G = dgl.BatchedDGLGraph( G = dgl.batch(
[train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None) [train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None)
G = G.to_networkx() G = G.to_networkx()
# hack to dodge the potential bugs of to_networkx # hack to dodge the potential bugs of to_networkx
......
...@@ -226,8 +226,8 @@ def collate_molgraphs(data): ...@@ -226,8 +226,8 @@ def collate_molgraphs(data):
------- -------
smiles : list smiles : list
List of smiles List of smiles
bg : BatchedDGLGraph bg : DGLGraph
Batched DGLGraphs The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T) labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and Batched datapoint labels. B is len(data) and
T is the number of total tasks. T is the number of total tasks.
......
...@@ -19,10 +19,10 @@ from ._ffi.function import register_func, get_global_func, list_global_func_name ...@@ -19,10 +19,10 @@ from ._ffi.function import register_func, get_global_func, list_global_func_name
from ._ffi.base import DGLError, __version__ from ._ffi.base import DGLError, __version__
from .base import ALL, NTYPE, NID, ETYPE, EID from .base import ALL, NTYPE, NID, ETYPE, EID
from .batched_graph import * from .readout import *
from .batched_heterograph import * from .batched_heterograph import *
from .convert import * from .convert import *
from .graph import DGLGraph from .graph import DGLGraph, batch, unbatch
from .generators import * from .generators import *
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from .nodeflow import * from .nodeflow import *
......
...@@ -12,7 +12,8 @@ from ..._ffi.ndarray import empty ...@@ -12,7 +12,8 @@ from ..._ffi.ndarray import empty
from ... import utils from ... import utils
from ...nodeflow import NodeFlow from ...nodeflow import NodeFlow
from ... import backend as F from ... import backend as F
from ... import subgraph from ...graph import DGLGraph
from ...base import NID, EID
try: try:
import Queue as queue import Queue as queue
...@@ -431,13 +432,17 @@ class LayerSampler(NodeFlowSampler): ...@@ -431,13 +432,17 @@ class LayerSampler(NodeFlowSampler):
nflows = [NodeFlow(self.g, obj) for obj in nfobjs] nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows return nflows
class EdgeSubgraph(subgraph.DGLSubGraph): class EdgeSubgraph(DGLGraph):
''' The subgraph sampled from an edge sampler. ''' The subgraph sampled from an edge sampler.
A user can access the head nodes and tail nodes of the subgraph directly. A user can access the head nodes and tail nodes of the subgraph directly.
''' '''
def __init__(self, parent, sgi, neg): def __init__(self, parent, sgi, neg):
super(EdgeSubgraph, self).__init__(parent, sgi) super(EdgeSubgraph, self).__init__(graph_data=sgi.graph,
readonly=True,
parent=parent)
self.ndata[NID] = sgi.induced_nodes.tousertensor()
self.edata[EID] = sgi.induced_edges.tousertensor()
self.sgi = sgi self.sgi = sgi
self.neg = neg self.neg = neg
self.head = None self.head = None
...@@ -735,7 +740,9 @@ class EdgeSampler(object): ...@@ -735,7 +740,9 @@ class EdgeSampler(object):
if self._negative_mode == "": if self._negative_mode == "":
# If no negative subgraphs. # If no negative subgraphs.
return [subgraph.DGLSubGraph(self.g, subg) for subg in subgs] return [self.g._create_subgraph(subg,
subg.induced_nodes,
subg.induced_edges) for subg in subgs]
else: else:
rets = [] rets = []
assert len(subgs) % 2 == 0 assert len(subgs) % 2 == 0
......
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..graph import DGLGraph from ..graph import DGLGraph
from ..batched_graph import BatchedDGLGraph
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -29,7 +28,7 @@ class GraphData(ObjectBase): ...@@ -29,7 +28,7 @@ class GraphData(ObjectBase):
@staticmethod @staticmethod
def create(g: DGLGraph): def create(g: DGLGraph):
"""Create GraphData""" """Create GraphData"""
assert not isinstance(g, BatchedDGLGraph), "BatchedDGLGraph is not supported for serialization" assert g.batch_size == 1, "BatchedDGLGraph is not supported for serialization"
ghandle = g._graph ghandle = g._graph
if len(g.ndata) != 0: if len(g.ndata) != 0:
node_tensors = dict() node_tensors = dict()
......
"""Dataset for stochastic block model.""" """Dataset for stochastic block model."""
import math import math
import os
import pickle
import random import random
import numpy as np import numpy as np
import numpy.random as npr import numpy.random as npr
import scipy as sp import scipy as sp
import networkx as nx
from ..batched_graph import batch from ..graph import DGLGraph, batch
from ..graph import DGLGraph
from ..utils import Index from ..utils import Index
def sbm(n_blocks, block_size, p, q, rng=None): def sbm(n_blocks, block_size, p, q, rng=None):
......
...@@ -3,10 +3,11 @@ from __future__ import absolute_import ...@@ -3,10 +3,11 @@ from __future__ import absolute_import
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterable
import networkx as nx import networkx as nx
import dgl import dgl
from .base import ALL, is_all, DGLError, dgl_warning from .base import ALL, NID, EID, is_all, DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import init from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
...@@ -16,7 +17,7 @@ from . import utils ...@@ -16,7 +17,7 @@ from . import utils
from .view import NodeView, EdgeView from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph'] __all__ = ['DGLGraph', 'batch', 'unbatch']
class DGLBaseGraph(object): class DGLBaseGraph(object):
"""Base graph class. """Base graph class.
...@@ -734,6 +735,24 @@ class DGLBaseGraph(object): ...@@ -734,6 +735,24 @@ class DGLBaseGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor() return self._graph.out_degrees(v).tousertensor()
def mutation(func):
    """A decorator for DGLGraph methods that might change graph structure.

    The wrapper rejects mutation on read-only graphs and emits a warning
    when the mutation would invalidate batching or subgraph bookkeeping.

    Parameters
    ----------
    func : callable
        A graph-mutating method taking the graph as its first argument.

    Returns
    -------
    callable
        The wrapped method.
    """
    def inner(g, *args, **kwargs):
        if g.is_readonly:
            raise DGLError("Readonly graph. Mutation is not allowed.")
        if g.batch_size > 1:
            dgl_warning("The graph has batch_size > 1, and mutation would break"
                        " batching related properties, call `flatten` to remove"
                        " batching information of the graph.")
        if g._parent is not None:
            dgl_warning("The graph is a subgraph of a parent graph, and mutation"
                        " would break subgraph related properties, call `detach"
                        "_parent` to remove its connection with its parent.")
        # Propagate the wrapped function's return value; the original
        # decorator silently dropped it.
        return func(g, *args, **kwargs)
    # Preserve metadata for docs/debugging without needing functools.
    inner.__name__ = func.__name__
    inner.__doc__ = func.__doc__
    return inner
class DGLGraph(DGLBaseGraph): class DGLGraph(DGLBaseGraph):
"""Base graph class. """Base graph class.
...@@ -902,7 +921,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -902,7 +921,10 @@ class DGLGraph(DGLBaseGraph):
edge_frame=None, edge_frame=None,
multigraph=None, multigraph=None,
readonly=False, readonly=False,
sort_csr=False): sort_csr=False,
batch_num_nodes=None,
batch_num_edges=None,
parent=None):
# graph # graph
if isinstance(graph_data, DGLGraph): if isinstance(graph_data, DGLGraph):
gidx = graph_data._graph gidx = graph_data._graph
...@@ -936,6 +958,22 @@ class DGLGraph(DGLBaseGraph): ...@@ -936,6 +958,22 @@ class DGLGraph(DGLBaseGraph):
self._apply_node_func = None self._apply_node_func = None
self._apply_edge_func = None self._apply_edge_func = None
# batched graph
self._batch_num_nodes = batch_num_nodes
self._batch_num_edges = batch_num_edges
# set parent if the graph is a induced subgraph.
self._parent = parent
def _create_subgraph(self, sgi, induced_nodes, induced_edges):
    """Internal helper: build a read-only subgraph of this graph.

    The induced node/edge ids are recorded in the NID/EID fields of the
    returned graph's ndata/edata so the subgraph can be mapped back to
    this (parent) graph.
    """
    child = DGLGraph(graph_data=sgi.graph,
                     readonly=True,
                     parent=self)
    child.ndata[NID] = induced_nodes.tousertensor()
    child.edata[EID] = induced_edges.tousertensor()
    return child
def _get_msg_index(self): def _get_msg_index(self):
if self._msg_index is None: if self._msg_index is None:
self._msg_index = utils.zero_index(size=self.number_of_edges()) self._msg_index = utils.zero_index(size=self.number_of_edges())
...@@ -944,6 +982,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -944,6 +982,7 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index): def _set_msg_index(self, index):
self._msg_index = index self._msg_index = index
@mutation
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
"""Add multiple new nodes. """Add multiple new nodes.
...@@ -995,6 +1034,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -995,6 +1034,7 @@ class DGLGraph(DGLBaseGraph):
else: else:
self._node_frame.append(data) self._node_frame.append(data)
@mutation
def add_edge(self, u, v, data=None): def add_edge(self, u, v, data=None):
"""Add one new edge between u and v. """Add one new edge between u and v.
...@@ -1053,6 +1093,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1053,6 +1093,7 @@ class DGLGraph(DGLBaseGraph):
self._msg_index = self._msg_index.append_zeros(1) self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1) self._msg_frame.add_rows(1)
@mutation
def add_edges(self, u, v, data=None): def add_edges(self, u, v, data=None):
"""Add multiple edges for list of source nodes u and destination nodes """Add multiple edges for list of source nodes u and destination nodes
v. A single edge is added between every pair of ``u[i]`` and ``v[i]``. v. A single edge is added between every pair of ``u[i]`` and ``v[i]``.
...@@ -1114,6 +1155,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1114,6 +1155,7 @@ class DGLGraph(DGLBaseGraph):
self._msg_index = self._msg_index.append_zeros(num) self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num) self._msg_frame.add_rows(num)
@mutation
def remove_nodes(self, vids): def remove_nodes(self, vids):
"""Remove multiple nodes, edges that have connection with these nodes would also be removed. """Remove multiple nodes, edges that have connection with these nodes would also be removed.
...@@ -1163,8 +1205,6 @@ class DGLGraph(DGLBaseGraph): ...@@ -1163,8 +1205,6 @@ class DGLGraph(DGLBaseGraph):
add_edges add_edges
remove_edges remove_edges
""" """
if self.is_readonly:
raise DGLError("remove_nodes is not supported by read-only graph.")
induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids)) induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids))
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
...@@ -1180,6 +1220,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1180,6 +1220,7 @@ class DGLGraph(DGLBaseGraph):
self._graph = sgi.graph self._graph = sgi.graph
@mutation
def remove_edges(self, eids): def remove_edges(self, eids):
"""Remove multiple edges. """Remove multiple edges.
...@@ -1226,8 +1267,6 @@ class DGLGraph(DGLBaseGraph): ...@@ -1226,8 +1267,6 @@ class DGLGraph(DGLBaseGraph):
add_edges add_edges
remove_nodes remove_nodes
""" """
if self.is_readonly:
raise DGLError("remove_edges is not supported by read-only graph.")
induced_edges = utils.set_diff( induced_edges = utils.set_diff(
utils.toindex(range(self.number_of_edges())), utils.toindex(eids)) utils.toindex(range(self.number_of_edges())), utils.toindex(eids))
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True) sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True)
...@@ -1244,6 +1283,112 @@ class DGLGraph(DGLBaseGraph): ...@@ -1244,6 +1283,112 @@ class DGLGraph(DGLBaseGraph):
self._graph = sgi.graph self._graph = sgi.graph
@property
def parent_nid(self):
    """Map from node ids in this subgraph to node ids in the parent graph.

    Returns
    -------
    Tensor
        The parent node id array.

    Raises
    ------
    DGLError
        If this graph is not a subgraph.
    """
    if self._parent is not None:
        return self.ndata[NID]
    raise DGLError("We only support parent_nid for subgraphs.")
@property
def parent_eid(self):
    """Map from edge ids in this subgraph to edge ids in the parent graph.

    Returns
    -------
    Tensor
        The parent edge id array.

    Raises
    ------
    DGLError
        If this graph is not a subgraph.
    """
    if self._parent is not None:
        return self.edata[EID]
    raise DGLError("We only support parent_eid for subgraphs.")
def copy_to_parent(self, inplace=False):
    """Write node/edge features of this subgraph back to its parent graph.

    Parameters
    ----------
    inplace : bool
        If true, use inplace write (no gradient but faster)

    Raises
    ------
    DGLError
        If this graph is not a subgraph.
    """
    if self._parent is None:
        raise DGLError("We only support copy_to_parent for subgraphs.")
    # Temporarily remove the induced-id fields so they are not written
    # into the parent's frames; restore them once the copy is done.
    induced_nodes = self.ndata.pop(NID)
    induced_edges = self.edata.pop(EID)
    parent = self._parent
    parent._node_frame.update_rows(
        utils.toindex(induced_nodes), self._node_frame, inplace=inplace)
    if parent._edge_frame.num_rows != 0:
        parent._edge_frame.update_rows(
            utils.toindex(induced_edges), self._edge_frame, inplace=inplace)
    self.ndata[NID] = induced_nodes
    self.edata[EID] = induced_edges
def copy_from_parent(self):
    """Copy node/edge features from the parent graph.

    All old features will be removed; the NID/EID induced-id fields are
    preserved.

    Raises
    ------
    DGLError
        If this graph is not a subgraph.
    """
    if self._parent is None:
        raise DGLError("We only support copy_from_parent for subgraphs.")
    nids = self.ndata[NID]
    eids = self.edata[EID]
    # Replacing the frames wipes the induced-id fields, so they are
    # re-inserted below.
    if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
        self._node_frame = FrameRef(Frame(
            self._parent._node_frame[utils.toindex(nids)]))
    if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
        self._edge_frame = FrameRef(Frame(
            self._parent._edge_frame[utils.toindex(eids)]))
    self.ndata[NID] = nids
    # BUGFIX: the original stored the edge ids under edata[NID], which put
    # them under the wrong key and lost the EID field (copy_to_parent and
    # parent_eid both expect edata[EID]).
    self.edata[EID] = eids
def map_to_subgraph_nid(self, parent_vids):
    """Map node ids in the parent graph to node ids in this subgraph.

    Parameters
    ----------
    parent_vids : list, tensor
        The node ID array in the parent graph.

    Returns
    -------
    tensor
        The node ID array in the subgraph.

    Raises
    ------
    DGLError
        If this graph is not a subgraph.
    """
    if self._parent is None:
        raise DGLError("We only support map_to_subgraph_nid for subgraphs.")
    induced = utils.toindex(self.ndata[NID])
    queries = utils.toindex(parent_vids)
    return graph_index.map_to_subgraph_nid(induced, queries).tousertensor()
def flatten(self):
"""Remove all batching information of the graph, and regard the current
graph as an independent graph rather then a batched graph.
Graph topology and attributes would not be affected.
"""
self._batch_num_nodes = None
self._batch_num_edges = None
def detach_parent(self):
    """Sever the link to the parent graph so the current graph is regarded
    as an independent graph rather than a subgraph.

    Graph topology and attributes would not be affected; only the induced
    node/edge id fields (NID/EID) are removed from ndata/edata.
    """
    self.ndata.pop(NID)
    self.edata.pop(EID)
    self._parent = None
def clear(self): def clear(self):
"""Remove all nodes and edges, as well as their features, from the """Remove all nodes and edges, as well as their features, from the
graph. graph.
...@@ -1710,6 +1855,53 @@ class DGLGraph(DGLBaseGraph): ...@@ -1710,6 +1855,53 @@ class DGLGraph(DGLBaseGraph):
""" """
return self.edges[:].data return self.edges[:].data
@property
def batch_size(self):
"""Number of graphs in this batch.
Returns
-------
int
Number of graphs in this batch."""
return 1 if self.batch_num_nodes is None else len(self.batch_num_nodes)
@property
def batch_num_nodes(self):
"""Number of nodes of each graph in this batch.
Returns
-------
list
Number of nodes of each graph in this batch."""
if self._batch_num_nodes is None:
return [self.number_of_nodes()]
else:
return self._batch_num_nodes
@property
def batch_num_edges(self):
"""Number of edges of each graph in this batch.
Returns
-------
list
Number of edges of each graph in this batch."""
if self._batch_num_edges is None:
return [self.number_of_edges()]
else:
return self._batch_num_edges
@property
def parent(self):
"""If current graph is a induced subgraph of a parent graph, return
its parent graph, else return None.
Returns
-------
DGLGraph or None
The parent graph of current graph.
"""
return self._parent
def init_ndata(self, ndata_name, shape, dtype, ctx=F.cpu()): def init_ndata(self, ndata_name, shape, dtype, ctx=F.cpu()):
"""Create node embedding. """Create node embedding.
...@@ -2914,7 +3106,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2914,7 +3106,7 @@ class DGLGraph(DGLBaseGraph):
Returns Returns
------- -------
G : DGLSubGraph G : DGLGraph
The subgraph. The subgraph.
The nodes are relabeled so that node `i` in the subgraph is mapped The nodes are relabeled so that node `i` in the subgraph is mapped
to node `nodes[i]` in the original graph. to node `nodes[i]` in the original graph.
...@@ -2942,14 +3134,12 @@ class DGLGraph(DGLBaseGraph): ...@@ -2942,14 +3134,12 @@ class DGLGraph(DGLBaseGraph):
See Also See Also
-------- --------
DGLSubGraph
subgraphs subgraphs
edge_subgraph edge_subgraph
""" """
from . import subgraph
induced_nodes = utils.toindex(nodes) induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
return subgraph.DGLSubGraph(self, sgi) return self._create_subgraph(sgi, sgi.induced_nodes, sgi.induced_edges)
def subgraphs(self, nodes): def subgraphs(self, nodes):
"""Return a list of subgraphs, each induced in the corresponding given """Return a list of subgraphs, each induced in the corresponding given
...@@ -2966,18 +3156,17 @@ class DGLGraph(DGLBaseGraph): ...@@ -2966,18 +3156,17 @@ class DGLGraph(DGLBaseGraph):
Returns Returns
------- -------
G : A list of DGLSubGraph G : A list of DGLGraph
The subgraphs. The subgraphs.
See Also See Also
-------- --------
DGLSubGraph
subgraph subgraph
""" """
from . import subgraph
induced_nodes = [utils.toindex(n) for n in nodes] induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes) sgis = self._graph.node_subgraphs(induced_nodes)
return [subgraph.DGLSubGraph(self, sgi) for sgi in sgis] return [self._create_subgraph(
sgi, sgi.induced_nodes, sgi.induced_edges) for sgi in sgis]
def edge_subgraph(self, edges, preserve_nodes=False): def edge_subgraph(self, edges, preserve_nodes=False):
"""Return the subgraph induced on given edges. """Return the subgraph induced on given edges.
...@@ -2994,7 +3183,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2994,7 +3183,7 @@ class DGLGraph(DGLBaseGraph):
Returns Returns
------- -------
G : DGLSubGraph G : DGLGraph
The subgraph. The subgraph.
The edges are relabeled so that edge `i` in the subgraph is mapped The edges are relabeled so that edge `i` in the subgraph is mapped
to edge `edges[i]` in the original graph. to edge `edges[i]` in the original graph.
...@@ -3031,13 +3220,11 @@ class DGLGraph(DGLBaseGraph): ...@@ -3031,13 +3220,11 @@ class DGLGraph(DGLBaseGraph):
See Also See Also
-------- --------
DGLSubGraph
subgraph subgraph
""" """
from . import subgraph
induced_edges = utils.toindex(edges) induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes) sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
return subgraph.DGLSubGraph(self, sgi) return self._create_subgraph(sgi, sgi.induced_nodes, sgi.induced_edges)
def adjacency_matrix_scipy(self, transpose=None, fmt='csr', return_edge_ids=None): def adjacency_matrix_scipy(self, transpose=None, fmt='csr', return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
...@@ -3427,9 +3614,13 @@ class DGLGraph(DGLBaseGraph): ...@@ -3427,9 +3614,13 @@ class DGLGraph(DGLBaseGraph):
# otherwise the default initializer will be used. # otherwise the default initializer will be used.
sync_frame_initializer(local_node_frame._frame, self._node_frame._frame) sync_frame_initializer(local_node_frame._frame, self._node_frame._frame)
sync_frame_initializer(local_edge_frame._frame, self._edge_frame._frame) sync_frame_initializer(local_edge_frame._frame, self._edge_frame._frame)
return DGLGraph(self._graph, return DGLGraph(graph_data=self._graph,
local_node_frame, node_frame=local_node_frame,
local_edge_frame) edge_frame=local_edge_frame,
readonly=self.is_readonly,
batch_num_nodes=self.batch_num_nodes,
batch_num_edges=self.batch_num_edges,
parent=self._parent)
@contextmanager @contextmanager
def local_scope(self): def local_scope(self):
...@@ -3489,6 +3680,178 @@ class DGLGraph(DGLBaseGraph): ...@@ -3489,6 +3680,178 @@ class DGLGraph(DGLBaseGraph):
self._node_frame = old_nframe self._node_frame = old_nframe
self._edge_frame = old_eframe self._edge_frame = old_eframe
############################################################
# Batch/Unbatch APIs
############################################################
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
    """Batch a collection of :class:`~dgl.DGLGraph` into a single graph.

    The returned :class:`DGLGraph` is independent of :attr:`graph_list`;
    its batch size equals the length of :attr:`graph_list`.

    Parameters
    ----------
    graph_list : iterable
        A collection of :class:`~dgl.DGLGraph` to be batched.
    node_attrs : None, str or iterable
        The node attributes to batch. ``None`` keeps no node attributes;
        the default (``ALL``) batches every node attribute; a ``str`` or
        an iterable selects exactly those attributes.
    edge_attrs : None, str or iterable, optional
        Same semantics as :attr:`node_attrs`, for edge attributes.

    Returns
    -------
    DGLGraph
        One single batched graph.

    See Also
    --------
    unbatch
    """
    # A singleton batch is returned as-is (no copy is made).
    if len(graph_list) == 1:
        return graph_list[0]

    def _collect_attrs(requested, mode):
        """Resolve the attribute selection for one mode ('node'/'edge').

        Returns [] for None, the validated common attribute set for ALL,
        or the caller's explicit selection.
        """
        if mode == 'node':
            sizes = [g.number_of_nodes() for g in graph_list]
            schemes = [set(g.node_attr_schemes().keys()) for g in graph_list]
        else:
            sizes = [g.number_of_edges() for g in graph_list]
            schemes = [set(g.edge_attr_schemes().keys()) for g in graph_list]
        if requested is None:
            return []
        if is_all(requested):
            # Take the first non-empty graph that has features as the
            # reference schema, then insist every other non-empty graph
            # agrees with it.
            ref_idx, ref_attrs = -1, set()
            for idx, (count, keys) in enumerate(zip(sizes, schemes)):
                if count > 0 and len(keys) > 0:
                    ref_idx, ref_attrs = idx, keys
                    break
            if len(ref_attrs) > 0:
                for idx, (count, keys) in enumerate(zip(sizes, schemes)):
                    if keys != ref_attrs and count > 0:
                        raise ValueError('Expect graph {0} and {1} to have the same {2} '
                                         'attributes when {2}_attrs=ALL, got {3} and {4}.'
                                         .format(ref_idx, idx, mode, ref_attrs, keys))
            return ref_attrs
        if isinstance(requested, str):
            return [requested]
        if isinstance(requested, Iterable):
            return requested
        raise ValueError('Expected {} attrs to be of type None str or Iterable, '
                         'got type {}'.format(mode, type(requested)))

    node_attrs = _collect_attrs(node_attrs, 'node')
    edge_attrs = _collect_attrs(edge_attrs, 'edge')

    # Disjoint union of the structures: ids are relabeled consecutively
    # in the order the graphs appear in graph_list.
    batched_index = graph_index.disjoint_union([g._graph for g in graph_list])

    def _concat_frames(attrs, frame_of, count_of, num_rows):
        """Concatenate the selected columns over all non-empty graphs."""
        if len(attrs) == 0:
            return FrameRef(Frame(num_rows=num_rows))
        # NOTE: this materializes the columns of the input graphs.
        columns = {}
        for key in attrs:
            columns[key] = F.cat([frame_of(g)[key] for g in graph_list
                                  if count_of(g) > 0], dim=0)
        return FrameRef(Frame(columns))

    batched_node_frame = _concat_frames(
        node_attrs, lambda g: g._node_frame, lambda g: g.number_of_nodes(),
        batched_index.number_of_nodes())
    batched_edge_frame = _concat_frames(
        edge_attrs, lambda g: g._edge_frame, lambda g: g.number_of_edges(),
        batched_index.number_of_edges())

    # Flatten nested batches: if an input is itself a batched graph, its
    # per-component node/edge counts are spliced into the new batch.
    batch_num_nodes = []
    batch_num_edges = []
    for g in graph_list:
        batch_num_nodes.extend(g.batch_num_nodes)
        batch_num_edges.extend(g.batch_num_edges)

    return DGLGraph(graph_data=batched_index,
                    node_frame=batched_node_frame,
                    edge_frame=batched_edge_frame,
                    batch_num_nodes=batch_num_nodes,
                    batch_num_edges=batch_num_edges)
def unbatch(graph):
    """Split a batched graph back into the list of its component graphs.

    Parameters
    ----------
    graph : DGLGraph
        The batched graph.

    Returns
    -------
    list
        A list of :class:`~dgl.DGLGraph` objects whose attributes are
        obtained by partitioning the attributes of :attr:`graph`. The
        length of the list equals the batch size of :attr:`graph`.

    Notes
    -----
    Unbatching breaks each field tensor of the batched graph into smaller
    partitions. For simpler tasks such as node/edge state aggregation,
    prefer readout functions.

    See Also
    --------
    batch
    """
    # A non-batched graph unbatches to itself.
    if graph.batch_size == 1:
        return [graph]
    num_graphs = graph.batch_size
    node_counts = graph.batch_num_nodes
    edge_counts = graph.batch_num_edges
    # Partition the structure back into per-component graph indices.
    pieces = graph_index.disjoint_partition(graph._graph,
                                            utils.toindex(node_counts))
    # One empty frame per component, then scatter each batched column
    # into per-component slices.
    node_frames = [FrameRef(Frame(num_rows=count)) for count in node_counts]
    edge_frames = [FrameRef(Frame(num_rows=count)) for count in edge_counts]
    for key, column in graph._node_frame.items():
        for frame, chunk in zip(node_frames, F.split(column, node_counts, dim=0)):
            frame[key] = chunk
    for key, column in graph._edge_frame.items():
        for frame, chunk in zip(edge_frames, F.split(column, edge_counts, dim=0)):
            frame[key] = chunk
    return [DGLGraph(graph_data=pieces[i],
                     node_frame=node_frames[i],
                     edge_frame=edge_frames[i]) for i in range(num_graphs)]
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
......
...@@ -1182,13 +1182,13 @@ def from_edge_list(elist, is_multigraph, readonly): ...@@ -1182,13 +1182,13 @@ def from_edge_list(elist, is_multigraph, readonly):
raise DGLError('Invalid edge list. Nodes must start from 0.') raise DGLError('Invalid edge list. Nodes must start from 0.')
return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly) return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly)
def map_to_subgraph_nid(induced_nodes, parent_nids):
    """Map parent node Ids to the subgraph node Ids.

    Parameters
    ----------
    induced_nodes : utils.Index
        Induced nodes of the subgraph.
    parent_nids : utils.Index
        Node Ids in the parent graph.

    Returns
    -------
    utils.Index
        Node Ids in the subgraph.
    """
    return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(),
                                                 parent_nids.todgltensor()))
def transform_ids(mapping, ids): def transform_ids(mapping, ids):
......
...@@ -147,7 +147,7 @@ class GetContext(nn.Module): ...@@ -147,7 +147,7 @@ class GetContext(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
...@@ -202,7 +202,7 @@ class GNNLayer(nn.Module): ...@@ -202,7 +202,7 @@ class GNNLayer(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
...@@ -248,7 +248,7 @@ class GlobalPool(nn.Module): ...@@ -248,7 +248,7 @@ class GlobalPool(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
...@@ -326,7 +326,7 @@ class AttentiveFP(nn.Module): ...@@ -326,7 +326,7 @@ class AttentiveFP(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment