"...text-generation-inference.git" did not exist on "cb0a29484d573125336a02dc2191479f18cacabe"
Unverified commit 93ac29ce, authored by Zihao Ye and committed by GitHub
Browse files

[Refactor] Unify DGLGraph, BatchedDGLGraph and DGLSubGraph (#1216)



* upd

* upd

* upd

* lint

* fix

* fix test

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd tutorial

* upd

* upd

* fix kg

* upd doc organization

* refresh test

* upd

* refactor doc

* fix lint
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent 8874e830
......@@ -8,6 +8,8 @@ import sys
import pickle
import time
from dgl.base import NID, EID
def SoftRelationPartition(edges, n, threshold=0.05):
"""This partitions a list of edges to n partitions according to their
relation types. For any relation with number of edges larger than the
......@@ -359,7 +361,7 @@ class TrainDataset(object):
return_false_neg=False)
class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
class ChunkNegEdgeSubgraph(dgl.DGLGraph):
"""Wrapper for negative graph
Parameters
......@@ -378,7 +380,11 @@ class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
"""
def __init__(self, subg, num_chunks, chunk_size,
neg_sample_size, neg_head):
super(ChunkNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi)
super(ChunkNegEdgeSubgraph, self).__init__(graph_data=subg.sgi.graph,
readonly=True,
parent=subg._parent)
self.ndata[NID] = subg.sgi.induced_nodes.tousertensor()
self.edata[EID] = subg.sgi.induced_edges.tousertensor()
self.subg = subg
self.num_chunks = num_chunks
self.chunk_size = chunk_size
......
"""MPNN"""
import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import Set2Set
from ..gnn import MPNNGNN
......@@ -77,6 +76,4 @@ class MPNNPredictor(nn.Module):
"""
node_feats = self.gnn(g, node_feats, edge_feats)
graph_feats = self.readout(g, node_feats)
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return self.predict(graph_feats)
......@@ -4,8 +4,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import BatchedDGLGraph
__all__ = ['AttentiveFPReadout']
class GlobalPool(nn.Module):
......@@ -59,10 +57,7 @@ class GlobalPool(nn.Module):
g.ndata['a'] = dgl.softmax_nodes(g, 'z')
g.ndata['hv'] = self.project_nodes(node_feats)
if isinstance(g, BatchedDGLGraph):
g_repr = dgl.sum_nodes(g, 'hv', 'a')
else:
g_repr = dgl.sum_nodes(g, 'hv', 'a').unsqueeze(0)
g_repr = dgl.sum_nodes(g, 'hv', 'a')
context = F.elu(g_repr)
if get_node_weight:
......@@ -121,9 +116,6 @@ class AttentiveFPReadout(nn.Module):
g.ndata['hv'] = node_feats
g_feats = dgl.sum_nodes(g, 'hv')
if not isinstance(g, BatchedDGLGraph):
g_feats = g_feats.unsqueeze(0)
if get_node_weight:
node_weights = []
......
......@@ -2,8 +2,6 @@
import dgl
import torch.nn as nn
from dgl import BatchedDGLGraph
__all__ = ['MLPNodeReadout']
class MLPNodeReadout(nn.Module):
......@@ -64,6 +62,4 @@ class MLPNodeReadout(nn.Module):
elif self.mode == 'sum':
graph_feats = dgl.sum_nodes(g, 'h')
if not isinstance(g, BatchedDGLGraph):
graph_feats = graph_feats.unsqueeze(0)
return graph_feats
......@@ -3,7 +3,6 @@ import dgl
import torch
import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax']
......@@ -45,10 +44,5 @@ class WeightedSumAndMax(nn.Module):
with bg.local_scope():
bg.ndata['h'] = feats
h_g_max = dgl.max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
return h_g
......@@ -93,8 +93,8 @@ def collate_molgraphs(data):
-------
smiles : list
List of smiles
bg : BatchedDGLGraph
Batched DGLGraphs
bg : DGLGraph
The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
......
......@@ -45,6 +45,24 @@ Querying graph structure
DGLGraph.out_degree
DGLGraph.out_degrees
Querying batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.batch_size
DGLGraph.batch_num_nodes
DGLGraph.batch_num_edges
Querying sub-graph/parent-graph belonging information
-----------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent
Removing nodes and edges
------------------------
......@@ -66,6 +84,8 @@ Transforming graph
DGLGraph.line_graph
DGLGraph.reverse
DGLGraph.readonly
DGLGraph.flatten
DGLGraph.detach_parent
Converting from/to other format
-------------------------------
......@@ -121,3 +141,29 @@ Computing with DGLGraph
DGLGraph.filter_nodes
DGLGraph.filter_edges
DGLGraph.to
Batch and Unbatch
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.parent_nid
DGLGraph.parent_eid
DGLGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.copy_from_parent
DGLGraph.copy_to_parent
......@@ -6,7 +6,7 @@ API Reference
graph
heterograph
batch
readout
batch_heterograph
nn
init
......@@ -17,7 +17,6 @@ API Reference
sampler
data
transform
subgraph
graph_store
nodeflow
random
......
.. _apibatch:
dgl.batched_graph
dgl.readout
==================================================
.. currentmodule:: dgl
.. autoclass:: BatchedDGLGraph
Merge and decompose
-------------------
.. autosummary::
:toctree: ../../generated/
batch
unbatch
Query batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
BatchedDGLGraph.batch_size
BatchedDGLGraph.batch_num_nodes
BatchedDGLGraph.batch_num_edges
Graph Readout
-------------
......
.. _apisubgraph:
dgl.subgraph
================================================
.. currentmodule:: dgl.subgraph
.. autoclass:: DGLSubGraph
Mapping between subgraph and parent graph
-----------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.parent_nid
DGLSubGraph.parent_eid
DGLSubGraph.map_to_subgraph_nid
Synchronize features between subgraph and parent graph
------------------------------------------------------
.. autosummary::
:toctree: ../../generated/
DGLSubGraph.copy_from_parent
DGLSubGraph.copy_to_parent
......@@ -101,12 +101,11 @@ a useful manual for in-depth developers.
api/python/graph
api/python/heterograph
api/python/batch
api/python/readout
api/python/batch_heterograph
api/python/nn
api/python/function
api/python/udf
api/python/subgraph
api/python/traversal
api/python/propagate
api/python/transform
......
......@@ -67,7 +67,7 @@ def load_data(args):
test_dataset = PPIDataset('test')
PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask',
'val_mask', 'features', 'labels', 'num_labels', 'graph'])
G = dgl.BatchedDGLGraph(
G = dgl.batch(
[train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None)
G = G.to_networkx()
# hack to dodge the potential bugs of to_networkx
......
......@@ -226,8 +226,8 @@ def collate_molgraphs(data):
-------
smiles : list
List of smiles
bg : BatchedDGLGraph
Batched DGLGraphs
bg : DGLGraph
The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
......
......@@ -19,10 +19,10 @@ from ._ffi.function import register_func, get_global_func, list_global_func_name
from ._ffi.base import DGLError, __version__
from .base import ALL, NTYPE, NID, ETYPE, EID
from .batched_graph import *
from .readout import *
from .batched_heterograph import *
from .convert import *
from .graph import DGLGraph
from .graph import DGLGraph, batch, unbatch
from .generators import *
from .heterograph import DGLHeteroGraph
from .nodeflow import *
......
......@@ -12,7 +12,8 @@ from ..._ffi.ndarray import empty
from ... import utils
from ...nodeflow import NodeFlow
from ... import backend as F
from ... import subgraph
from ...graph import DGLGraph
from ...base import NID, EID
try:
import Queue as queue
......@@ -431,13 +432,17 @@ class LayerSampler(NodeFlowSampler):
nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows
class EdgeSubgraph(subgraph.DGLSubGraph):
class EdgeSubgraph(DGLGraph):
''' The subgraph sampled from an edge sampler.
A user can access the head nodes and tail nodes of the subgraph directly.
'''
def __init__(self, parent, sgi, neg):
super(EdgeSubgraph, self).__init__(parent, sgi)
super(EdgeSubgraph, self).__init__(graph_data=sgi.graph,
readonly=True,
parent=parent)
self.ndata[NID] = sgi.induced_nodes.tousertensor()
self.edata[EID] = sgi.induced_edges.tousertensor()
self.sgi = sgi
self.neg = neg
self.head = None
......@@ -735,7 +740,9 @@ class EdgeSampler(object):
if self._negative_mode == "":
# If no negative subgraphs.
return [subgraph.DGLSubGraph(self.g, subg) for subg in subgs]
return [self.g._create_subgraph(subg,
subg.induced_nodes,
subg.induced_edges) for subg in subgs]
else:
rets = []
assert len(subgs) % 2 == 0
......
"""For Graph Serialization"""
from __future__ import absolute_import
from ..graph import DGLGraph
from ..batched_graph import BatchedDGLGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
......@@ -29,7 +28,7 @@ class GraphData(ObjectBase):
@staticmethod
def create(g: DGLGraph):
"""Create GraphData"""
assert not isinstance(g, BatchedDGLGraph), "BatchedDGLGraph is not supported for serialization"
assert g.batch_size == 1, "BatchedDGLGraph is not supported for serialization"
ghandle = g._graph
if len(g.ndata) != 0:
node_tensors = dict()
......
"""Dataset for stochastic block model."""
import math
import os
import pickle
import random
import numpy as np
import numpy.random as npr
import scipy as sp
import networkx as nx
from ..batched_graph import batch
from ..graph import DGLGraph
from ..graph import DGLGraph, batch
from ..utils import Index
def sbm(n_blocks, block_size, p, q, rng=None):
......
This diff is collapsed.
......@@ -1182,13 +1182,13 @@ def from_edge_list(elist, is_multigraph, readonly):
raise DGLError('Invalid edge list. Nodes must start from 0.')
return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly)
def map_to_subgraph_nid(subgraph, parent_nids):
def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
Parameters
----------
subgraph: SubgraphIndex
the graph index of a subgraph
induced_nodes: utils.Index
Induced nodes of the subgraph.
parent_nids: utils.Index
Node Ids in the parent graph.
......@@ -1198,7 +1198,7 @@ def map_to_subgraph_nid(subgraph, parent_nids):
utils.Index
Node Ids in the subgraph.
"""
return utils.toindex(_CAPI_DGLMapSubgraphNID(subgraph.induced_nodes.todgltensor(),
return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(),
parent_nids.todgltensor()))
def transform_ids(mapping, ids):
......
......@@ -147,7 +147,7 @@ class GetContext(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......@@ -202,7 +202,7 @@ class GNNLayer(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......@@ -248,7 +248,7 @@ class GlobalPool(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......@@ -326,7 +326,7 @@ class AttentiveFP(nn.Module):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment