"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0ab8fe49bf540b4f34ac5934c304da23ffd448e5"
Unverified Commit 93ac29ce authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: default avatarMinjie Wang <minjie.wang@nyu.edu>
parent 8874e830
...@@ -8,6 +8,8 @@ import sys ...@@ -8,6 +8,8 @@ import sys
import pickle import pickle
import time import time
from dgl.base import NID, EID
def SoftRelationPartition(edges, n, threshold=0.05): def SoftRelationPartition(edges, n, threshold=0.05):
"""This partitions a list of edges to n partitions according to their """This partitions a list of edges to n partitions according to their
relation types. For any relation with number of edges larger than the relation types. For any relation with number of edges larger than the
...@@ -359,7 +361,7 @@ class TrainDataset(object): ...@@ -359,7 +361,7 @@ class TrainDataset(object):
return_false_neg=False) return_false_neg=False)
class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph): class ChunkNegEdgeSubgraph(dgl.DGLGraph):
"""Wrapper for negative graph """Wrapper for negative graph
Parameters Parameters
...@@ -378,7 +380,11 @@ class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph): ...@@ -378,7 +380,11 @@ class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
""" """
def __init__(self, subg, num_chunks, chunk_size, def __init__(self, subg, num_chunks, chunk_size,
neg_sample_size, neg_head): neg_sample_size, neg_head):
super(ChunkNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi) super(ChunkNegEdgeSubgraph, self).__init__(graph_data=subg.sgi.graph,
readonly=True,
parent=subg._parent)
self.ndata[NID] = subg.sgi.induced_nodes.tousertensor()
self.edata[EID] = subg.sgi.induced_edges.tousertensor()
self.subg = subg self.subg = subg
self.num_chunks = num_chunks self.num_chunks = num_chunks
self.chunk_size = chunk_size self.chunk_size = chunk_size
......
"""MPNN""" """MPNN"""
import torch.nn as nn import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import Set2Set from dgl.nn.pytorch import Set2Set
from ..gnn import MPNNGNN from ..gnn import MPNNGNN
...@@ -77,6 +76,4 @@ class MPNNPredictor(nn.Module): ...@@ -77,6 +76,4 @@ class MPNNPredictor(nn.Module):
""" """
node_feats = self.gnn(g, node_feats, edge_feats) node_feats = self.gnn(g, node_feats, edge_feats)
graph_feats = self.readout(g, node_feats) graph_feats = self.readout(g, node_feats)
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return self.predict(graph_feats) return self.predict(graph_feats)
...@@ -4,8 +4,6 @@ import torch ...@@ -4,8 +4,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import BatchedDGLGraph
__all__ = ['AttentiveFPReadout'] __all__ = ['AttentiveFPReadout']
class GlobalPool(nn.Module): class GlobalPool(nn.Module):
...@@ -59,10 +57,7 @@ class GlobalPool(nn.Module): ...@@ -59,10 +57,7 @@ class GlobalPool(nn.Module):
g.ndata['a'] = dgl.softmax_nodes(g, 'z') g.ndata['a'] = dgl.softmax_nodes(g, 'z')
g.ndata['hv'] = self.project_nodes(node_feats) g.ndata['hv'] = self.project_nodes(node_feats)
if isinstance(g, BatchedDGLGraph): g_repr = dgl.sum_nodes(g, 'hv', 'a')
g_repr = dgl.sum_nodes(g, 'hv', 'a')
else:
g_repr = dgl.sum_nodes(g, 'hv', 'a').unsqueeze(0)
context = F.elu(g_repr) context = F.elu(g_repr)
if get_node_weight: if get_node_weight:
...@@ -121,9 +116,6 @@ class AttentiveFPReadout(nn.Module): ...@@ -121,9 +116,6 @@ class AttentiveFPReadout(nn.Module):
g.ndata['hv'] = node_feats g.ndata['hv'] = node_feats
g_feats = dgl.sum_nodes(g, 'hv') g_feats = dgl.sum_nodes(g, 'hv')
if not isinstance(g, BatchedDGLGraph):
g_feats = g_feats.unsqueeze(0)
if get_node_weight: if get_node_weight:
node_weights = [] node_weights = []
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
import dgl import dgl
import torch.nn as nn import torch.nn as nn
from dgl import BatchedDGLGraph
__all__ = ['MLPNodeReadout'] __all__ = ['MLPNodeReadout']
class MLPNodeReadout(nn.Module): class MLPNodeReadout(nn.Module):
...@@ -64,6 +62,4 @@ class MLPNodeReadout(nn.Module): ...@@ -64,6 +62,4 @@ class MLPNodeReadout(nn.Module):
elif self.mode == 'sum': elif self.mode == 'sum':
graph_feats = dgl.sum_nodes(g, 'h') graph_feats = dgl.sum_nodes(g, 'h')
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return graph_feats return graph_feats
...@@ -3,7 +3,6 @@ import dgl ...@@ -3,7 +3,6 @@ import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import WeightAndSum from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax'] __all__ = ['WeightedSumAndMax']
...@@ -45,10 +44,5 @@ class WeightedSumAndMax(nn.Module): ...@@ -45,10 +44,5 @@ class WeightedSumAndMax(nn.Module):
with bg.local_scope(): with bg.local_scope():
bg.ndata['h'] = feats bg.ndata['h'] = feats
h_g_max = dgl.max_nodes(bg, 'h') h_g_max = dgl.max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1) h_g = torch.cat([h_g_sum, h_g_max], dim=1)
return h_g return h_g
...@@ -93,8 +93,8 @@ def collate_molgraphs(data): ...@@ -93,8 +93,8 @@ def collate_molgraphs(data):
------- -------
smiles : list smiles : list
List of smiles List of smiles
bg : BatchedDGLGraph bg : DGLGraph
Batched DGLGraphs The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T) labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and Batched datapoint labels. B is len(data) and
T is the number of total tasks. T is the number of total tasks.
......
...@@ -45,6 +45,24 @@ Querying graph structure ...@@ -45,6 +45,24 @@ Querying graph structure
DGLGraph.out_degree DGLGraph.out_degree
DGLGraph.out_degrees DGLGraph.out_degrees
Querying batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.batch_size
DGLGraph.batch_num_nodes
DGLGraph.batch_num_edges
Querying sub-graph/parent-graph belonging information
-----------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent
Removing nodes and edges Removing nodes and edges
------------------------ ------------------------
...@@ -66,6 +84,8 @@ Transforming graph ...@@ -66,6 +84,8 @@ Transforming graph
DGLGraph.line_graph DGLGraph.line_graph
DGLGraph.reverse DGLGraph.reverse
DGLGraph.readonly DGLGraph.readonly
DGLGraph.flatten
DGLGraph.detach_parent
Converting from/to other format Converting from/to other format
------------------------------- -------------------------------
...@@ -121,3 +141,29 @@ Computing with DGLGraph ...@@ -121,3 +141,29 @@ Computing with DGLGraph
DGLGraph.filter_nodes DGLGraph.filter_nodes
DGLGraph.filter_edges DGLGraph.filter_edges
DGLGraph.to DGLGraph.to
Batch and Unbatch
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent_nid
DGLGraph.parent_eid
DGLGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.copy_from_parent
DGLGraph.copy_to_parent
...@@ -6,7 +6,7 @@ API Reference ...@@ -6,7 +6,7 @@ API Reference
graph graph
heterograph heterograph
batch readout
batch_heterograph batch_heterograph
nn nn
init init
...@@ -17,7 +17,6 @@ API Reference ...@@ -17,7 +17,6 @@ API Reference
sampler sampler
data data
transform transform
subgraph
graph_store graph_store
nodeflow nodeflow
random random
......
.. _apibatch: .. _apibatch:
dgl.batched_graph dgl.readout
================================================== ==================================================
.. currentmodule:: dgl .. currentmodule:: dgl
.. autoclass:: BatchedDGLGraph
Merge and decompose
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Query batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
BatchedDGLGraph.batch_size
BatchedDGLGraph.batch_num_nodes
BatchedDGLGraph.batch_num_edges
Graph Readout Graph Readout
------------- -------------
......
.. _apisubgraph:
dgl.subgraph
================================================
.. currentmodule:: dgl.subgraph
.. autoclass:: DGLSubGraph
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.parent_nid
DGLSubGraph.parent_eid
DGLSubGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.copy_from_parent
DGLSubGraph.copy_to_parent
...@@ -101,12 +101,11 @@ a useful manual for in-depth developers. ...@@ -101,12 +101,11 @@ a useful manual for in-depth developers.
api/python/graph api/python/graph
api/python/heterograph api/python/heterograph
api/python/batch api/python/readout
api/python/batch_heterograph api/python/batch_heterograph
api/python/nn api/python/nn
api/python/function api/python/function
api/python/udf api/python/udf
api/python/subgraph
api/python/traversal api/python/traversal
api/python/propagate api/python/propagate
api/python/transform api/python/transform
......
...@@ -67,7 +67,7 @@ def load_data(args): ...@@ -67,7 +67,7 @@ def load_data(args):
test_dataset = PPIDataset('test') test_dataset = PPIDataset('test')
PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask', PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask',
'val_mask', 'features', 'labels', 'num_labels', 'graph']) 'val_mask', 'features', 'labels', 'num_labels', 'graph'])
G = dgl.BatchedDGLGraph( G = dgl.batch(
[train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None) [train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None)
G = G.to_networkx() G = G.to_networkx()
# hack to dodge the potential bugs of to_networkx # hack to dodge the potential bugs of to_networkx
......
...@@ -226,8 +226,8 @@ def collate_molgraphs(data): ...@@ -226,8 +226,8 @@ def collate_molgraphs(data):
------- -------
smiles : list smiles : list
List of smiles List of smiles
bg : BatchedDGLGraph bg : DGLGraph
Batched DGLGraphs The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T) labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and Batched datapoint labels. B is len(data) and
T is the number of total tasks. T is the number of total tasks.
......
...@@ -19,10 +19,10 @@ from ._ffi.function import register_func, get_global_func, list_global_func_name ...@@ -19,10 +19,10 @@ from ._ffi.function import register_func, get_global_func, list_global_func_name
from ._ffi.base import DGLError, __version__ from ._ffi.base import DGLError, __version__
from .base import ALL, NTYPE, NID, ETYPE, EID from .base import ALL, NTYPE, NID, ETYPE, EID
from .batched_graph import * from .readout import *
from .batched_heterograph import * from .batched_heterograph import *
from .convert import * from .convert import *
from .graph import DGLGraph from .graph import DGLGraph, batch, unbatch
from .generators import * from .generators import *
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from .nodeflow import * from .nodeflow import *
......
...@@ -12,7 +12,8 @@ from ..._ffi.ndarray import empty ...@@ -12,7 +12,8 @@ from ..._ffi.ndarray import empty
from ... import utils from ... import utils
from ...nodeflow import NodeFlow from ...nodeflow import NodeFlow
from ... import backend as F from ... import backend as F
from ... import subgraph from ...graph import DGLGraph
from ...base import NID, EID
try: try:
import Queue as queue import Queue as queue
...@@ -431,13 +432,17 @@ class LayerSampler(NodeFlowSampler): ...@@ -431,13 +432,17 @@ class LayerSampler(NodeFlowSampler):
nflows = [NodeFlow(self.g, obj) for obj in nfobjs] nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows return nflows
class EdgeSubgraph(subgraph.DGLSubGraph): class EdgeSubgraph(DGLGraph):
''' The subgraph sampled from an edge sampler. ''' The subgraph sampled from an edge sampler.
A user can access the head nodes and tail nodes of the subgraph directly. A user can access the head nodes and tail nodes of the subgraph directly.
''' '''
def __init__(self, parent, sgi, neg): def __init__(self, parent, sgi, neg):
super(EdgeSubgraph, self).__init__(parent, sgi) super(EdgeSubgraph, self).__init__(graph_data=sgi.graph,
readonly=True,
parent=parent)
self.ndata[NID] = sgi.induced_nodes.tousertensor()
self.edata[EID] = sgi.induced_edges.tousertensor()
self.sgi = sgi self.sgi = sgi
self.neg = neg self.neg = neg
self.head = None self.head = None
...@@ -735,7 +740,9 @@ class EdgeSampler(object): ...@@ -735,7 +740,9 @@ class EdgeSampler(object):
if self._negative_mode == "": if self._negative_mode == "":
# If no negative subgraphs. # If no negative subgraphs.
return [subgraph.DGLSubGraph(self.g, subg) for subg in subgs] return [self.g._create_subgraph(subg,
subg.induced_nodes,
subg.induced_edges) for subg in subgs]
else: else:
rets = [] rets = []
assert len(subgs) % 2 == 0 assert len(subgs) % 2 == 0
......
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..graph import DGLGraph from ..graph import DGLGraph
from ..batched_graph import BatchedDGLGraph
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -29,7 +28,7 @@ class GraphData(ObjectBase): ...@@ -29,7 +28,7 @@ class GraphData(ObjectBase):
@staticmethod @staticmethod
def create(g: DGLGraph): def create(g: DGLGraph):
"""Create GraphData""" """Create GraphData"""
assert not isinstance(g, BatchedDGLGraph), "BatchedDGLGraph is not supported for serialization" assert g.batch_size == 1, "BatchedDGLGraph is not supported for serialization"
ghandle = g._graph ghandle = g._graph
if len(g.ndata) != 0: if len(g.ndata) != 0:
node_tensors = dict() node_tensors = dict()
......
"""Dataset for stochastic block model.""" """Dataset for stochastic block model."""
import math import math
import os
import pickle
import random import random
import numpy as np import numpy as np
import numpy.random as npr import numpy.random as npr
import scipy as sp import scipy as sp
import networkx as nx
from ..batched_graph import batch from ..graph import DGLGraph, batch
from ..graph import DGLGraph
from ..utils import Index from ..utils import Index
def sbm(n_blocks, block_size, p, q, rng=None): def sbm(n_blocks, block_size, p, q, rng=None):
......
This diff is collapsed.
...@@ -1182,13 +1182,13 @@ def from_edge_list(elist, is_multigraph, readonly): ...@@ -1182,13 +1182,13 @@ def from_edge_list(elist, is_multigraph, readonly):
raise DGLError('Invalid edge list. Nodes must start from 0.') raise DGLError('Invalid edge list. Nodes must start from 0.')
return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly) return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly)
def map_to_subgraph_nid(subgraph, parent_nids): def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids. """Map parent node Ids to the subgraph node Ids.
Parameters Parameters
---------- ----------
subgraph: SubgraphIndex induced_nodes: utils.Index
the graph index of a subgraph Induced nodes of the subgraph.
parent_nids: utils.Index parent_nids: utils.Index
Node Ids in the parent graph. Node Ids in the parent graph.
...@@ -1198,7 +1198,7 @@ def map_to_subgraph_nid(subgraph, parent_nids): ...@@ -1198,7 +1198,7 @@ def map_to_subgraph_nid(subgraph, parent_nids):
utils.Index utils.Index
Node Ids in the subgraph. Node Ids in the subgraph.
""" """
return utils.toindex(_CAPI_DGLMapSubgraphNID(subgraph.induced_nodes.todgltensor(), return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(),
parent_nids.todgltensor())) parent_nids.todgltensor()))
def transform_ids(mapping, ids): def transform_ids(mapping, ids):
......
...@@ -147,7 +147,7 @@ class GetContext(nn.Module): ...@@ -147,7 +147,7 @@ class GetContext(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
...@@ -202,7 +202,7 @@ class GNNLayer(nn.Module): ...@@ -202,7 +202,7 @@ class GNNLayer(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
...@@ -248,7 +248,7 @@ class GlobalPool(nn.Module): ...@@ -248,7 +248,7 @@ class GlobalPool(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
...@@ -326,7 +326,7 @@ class AttentiveFP(nn.Module): ...@@ -326,7 +326,7 @@ class AttentiveFP(nn.Module):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph or BatchedDGLGraph g : DGLGraph
Constructed DGLGraphs. Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1) node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size. Input node features. V for the number of nodes and N1 for the feature size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment