Commit 57b07fce authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[Model] Accelerate SSE (#172)

* lazy eval edge ids.

* parallelize node id lookup.

* fix a bug.

* use mxnet for index.

* use update_all in the subgraph training.

* sample neighbors.

* Revert "parallelize node id lookup."

This reverts commit e661f69bb06cb5a7c246f0e106f245e27800e220.

* update README.

* cache subgraphs.

* support all degrees.

* cache adj in CPU.

* fix a bug in sse.

* print.

* raise error on mutable graphs.

* measure train time per epoch.

* fix a bug in graph_index.

* remove readonly in DGLSubGraph.

* cache subgraph properly.

* accelerate getting adjacency.

* split infer.

* lazy eval edges.

* specify inference mode.

* update for new sampler.

* use new mxnet sampling api.

* fix indent.

* remove profiling code.

* remove mxnet from sampler.

* return a lambda function for graph edges.

* add docs for immutable subgraph.

* Revert "return a lambda function for graph edges."

This reverts commit 0de5d7f100e230c518a3fb8976a6227f474d09ee.

* get parent_eid.
parent 5278220f
# Benchmark SSE on multi-GPUs
# Use a small embedding.
DGLBACKEND=mxnet python3 -m pyinstrument -o prof.out examples/mxnet/sse/sse_batch.py --graph-file ../../data/5_5_csr.nd --n-epochs 1 --lr 0.0005 --batch-size 1024 --use-spmv 1 --num-parallel-subgraphs 32 --gpu 8
# Use a large embedding.
DGLBACKEND=mxnet python3 examples/mxnet/sse/sse_batch.py --graph-file ../../data/5_5_csr.nd --n-epochs 1 --lr 0.0005 --batch-size 2048 --use-spmv 1 --num-parallel-subgraphs 32 --gpu 8 --n-hidden 500
DGLBACKEND=mxnet python3 examples/mxnet/sse/sse_batch.py --graph-file ../../data/5_5_csr.nd --n-epochs 1 --lr 0.0005 --batch-size 1024 --use-spmv --dgl --num-parallel-subgraphs 32 --gpu 1 --num-feats 100 --n-hidden 100
# test convergence
DGLBACKEND=mxnet python3 examples/mxnet/sse/sse_batch.py --dataset "pubmed" --n-epochs 100 --lr 0.005 --batch-size 1024 --use-spmv 1
DGLBACKEND=mxnet python3 examples/mxnet/sse/sse_batch.py --dataset "pubmed" --n-epochs 100 --lr 0.001 --batch-size 1024 --dgl --use-spmv --neigh-expand 4
......@@ -4,6 +4,7 @@ Paper: http://proceedings.mlr.press/v80/dai18a.html
"""
import argparse
import random
import numpy as np
import time
import math
......@@ -52,14 +53,23 @@ class SSEUpdateHidden(gluon.Block):
with self.name_scope():
self.layer = NodeUpdate(n_hidden, activation)
self.dropout = dropout
self.n_hidden = n_hidden
def forward(self, g, vertices):
if vertices is None:
deg = mx.nd.expand_dims(g.in_degrees(np.arange(0, g.number_of_nodes())), 1).astype(np.float32)
deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
feat = g.get_n_repr()['in']
cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
return self.layer(feat, g.ndata['h'], accum)
batch_size = 100000
num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
ret = mx.nd.empty(shape=(feat.shape[0], self.n_hidden), ctx=feat.context)
for i in range(num_batches):
vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
ret[vs] = self.layer(mx.nd.take(feat, vs),
mx.nd.take(g.ndata['h'], vs),
mx.nd.take(accum, vs))
return ret
else:
deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
# We don't need dropout for inference.
......@@ -80,12 +90,14 @@ class DGLSSEUpdateHidden(gluon.Block):
activation,
dropout,
use_spmv,
inference,
**kwargs):
super(DGLSSEUpdateHidden, self).__init__(**kwargs)
with self.name_scope():
self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
self.dropout = dropout
self.use_spmv = use_spmv
self.inference = inference
def forward(self, g, vertices):
if self.use_spmv:
......@@ -97,7 +109,7 @@ class DGLSSEUpdateHidden(gluon.Block):
else:
msg_func = gcn_msg
reduce_func = gcn_reduce
deg = mx.nd.expand_dims(g.in_degrees(np.arange(0, g.number_of_nodes())), 1).astype(np.float32)
deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
if vertices is None:
g.update_all(msg_func, reduce_func, None)
if self.use_spmv:
......@@ -107,7 +119,7 @@ class DGLSSEUpdateHidden(gluon.Block):
num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
for i in range(num_batches):
vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
g.apply_nodes(self.layer, vs, inplace=True)
g.apply_nodes(self.layer, vs, inplace=self.inference)
g.ndata.pop('accum')
return g.get_n_repr()['h1']
else:
......@@ -115,14 +127,15 @@ class DGLSSEUpdateHidden(gluon.Block):
if self.dropout:
# TODO here we apply dropout on all vertex representation.
g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
g.pull(vertices, msg_func, reduce_func, None)
g.update_all(msg_func, reduce_func, None)
ctx = g.ndata['accum'].context
if self.use_spmv:
g.ndata.pop('cat')
deg = deg.as_in_context(g.ndata['accum'].context)
deg = deg.as_in_context(ctx)
g.ndata['accum'] = g.ndata['accum'] / deg
g.apply_nodes(self.layer, vertices)
g.apply_nodes(self.layer, vertices, inplace=self.inference)
g.ndata.pop('accum')
return g.ndata['h1'][vertices.as_in_context(g.ndata['h1'].context)]
return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))
class SSEPredict(gluon.Block):
def __init__(self, update_hidden, out_feats, dropout, **kwargs):
......@@ -144,6 +157,42 @@ def copy_to_gpu(subg, ctx):
for key in frame:
subg.ndata[key] = frame[key].as_in_context(ctx)
class CachedSubgraph(object):
    """Immutable record of one sampled subgraph plus its seed nodes.

    Parameters
    ----------
    subg : DGLSubGraph
        The freshly sampled subgraph.
    seeds : tensor
        The seed (parent) node ids this subgraph was sampled for.
    subg_seeds : tensor
        The seed ids mapped into the subgraph's local node space.
    """
    def __init__(self, subg, seeds, subg_seeds):
        # Rebuild a bare DGLSubGraph from the structural pieces only, so the
        # cache does not keep the original subgraph's frames (feature data)
        # alive between epochs.
        self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid,
                                    subg._parent_eid, subg._graph)
        self.seeds = seeds
        self.subg_seeds = subg_seeds
class CachedSubgraphLoader(object):
    """Iterate sampled subgraphs, caching them so later epochs replay the
    cache instead of re-sampling from the wrapped loader.

    Parameters
    ----------
    loader : iterator
        The underlying sampler; ``next()`` must yield ``(subgraph, seeds)``.
    shuffle : bool
        Whether to shuffle the cached subgraphs on each ``restart()``.
    """
    def __init__(self, loader, shuffle):
        self._loader = loader
        self._shuffle = shuffle
        self._cached = []
        # State for the current pass; valid even before the first restart().
        self._subgraphs = []
        self._gen_subgraph = True

    def restart(self):
        """Start a new pass: serve cached subgraphs if any, else sample."""
        self._subgraphs = self._cached
        # Only pull from the underlying loader when nothing was cached yet
        # (i.e. the very first pass).
        self._gen_subgraph = len(self._subgraphs) == 0
        # Honor the shuffle flag instead of always shuffling.
        if self._shuffle:
            random.shuffle(self._subgraphs)
        self._cached = []

    def __iter__(self):
        return self

    def __next__(self):
        if len(self._subgraphs) > 0:
            s = self._subgraphs.pop(0)
            subg, seeds, subg_seeds = s.subg, s.seeds, s.subg_seeds
        elif self._gen_subgraph:
            subg, seeds = next(self._loader)
            subg_seeds = subg.map_to_subgraph_nid(seeds)
        else:
            raise StopIteration
        # Re-cache every served subgraph so the next restart() can replay it.
        self._cached.append(CachedSubgraph(subg, seeds, subg_seeds))
        return subg, seeds, subg_seeds
def main(args, data):
if isinstance(data.features, mx.nd.NDArray):
features = data.features
......@@ -154,13 +203,12 @@ def main(args, data):
else:
labels = mx.nd.array(data.labels)
train_size = len(labels) * args.train_percent
train_vs = np.arange(train_size, dtype='int64')
eval_vs = np.arange(train_size, len(labels), dtype='int64')
train_vs = mx.nd.arange(0, train_size, dtype='int64')
eval_vs = mx.nd.arange(train_size, len(labels), dtype='int64')
print("train size: " + str(len(train_vs)))
print("eval size: " + str(len(eval_vs)))
eval_labels = mx.nd.array(data.labels[eval_vs])
eval_labels = mx.nd.take(labels, eval_vs)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
# create the SSE model
......@@ -175,10 +223,10 @@ def main(args, data):
update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv,
prefix='sse')
inference=True, prefix='sse')
update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv,
prefix='sse')
inference=False, prefix='sse')
if not args.dgl:
update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
prefix='sse')
......@@ -209,17 +257,24 @@ def main(args, data):
rets = []
rets.append(all_hidden)
if args.neigh_expand <= 0:
neigh_expand = g.number_of_nodes()
else:
neigh_expand = args.neigh_expand
# initialize graph
dur = []
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
shuffle=True)
sampler = CachedSubgraphLoader(sampler, shuffle=True)
for epoch in range(args.n_epochs):
t0 = time.time()
train_loss = 0
i = 0
num_batches = len(train_vs) / args.batch_size
start1 = time.time()
for subg, seeds in dgl.contrib.sampling.NeighborSampler(g, args.batch_size, g.number_of_nodes(),
neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
shuffle=True):
sampler.restart()
for subg, seeds, subg_seeds in sampler:
subg.copy_from_parent()
losses = []
......@@ -227,10 +282,9 @@ def main(args, data):
ctx = mx.gpu(i % args.gpu)
copy_to_gpu(subg, ctx)
subg_seeds = subg.map_to_subgraph_nid(seeds)
with mx.autograd.record():
logits = model_train(subg, subg_seeds.tousertensor())
batch_labels = mx.nd.array(labels[seeds.asnumpy()], ctx=logits.context)
batch_labels = mx.nd.take(labels, seeds).as_in_context(logits.context)
loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
loss.backward()
losses.append(loss)
......@@ -255,7 +309,7 @@ def main(args, data):
break
# prediction.
logits = model_infer(g, mx.nd.array(eval_vs, dtype=np.int64))
logits = model_infer(g, eval_vs)
eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
eval_loss = eval_loss.asnumpy()[0]
......@@ -297,8 +351,7 @@ class GraphData:
csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
self.graph = MXNetGraph(csr)
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.labels = mx.nd.floor(mx.nd.random.normal(loc=0, scale=10, shape=(csr.shape[0])))
self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0])))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
......@@ -330,6 +383,8 @@ if __name__ == '__main__':
parser.add_argument("--dgl", action="store_true")
parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
help="the number of subgraphs to construct in parallel.")
parser.add_argument("--neigh-expand", type=int, default=16,
help="the number of neighbors to sample.")
args = parser.parse_args()
# load and preprocess dataset
......
......@@ -27,6 +27,7 @@ class ImmutableGraphIndex(object):
def __init__(self, in_csr, out_csr):
self._in_csr = in_csr
self._out_csr = out_csr
self._cached_adj = {}
def number_of_nodes(self):
"""Return the number of nodes.
......@@ -232,7 +233,7 @@ class ImmutableGraphIndex(object):
# stores the edge Ids of the original graph.
csr = mx.nd.contrib.dgl_subgraph(self._in_csr, v, return_mapping=True)
induced_nodes = v
induced_edges = csr[1].data
induced_edges = lambda: csr[1].data
return ImmutableGraphIndex(csr[0], None), induced_nodes, induced_edges
def node_subgraphs(self, vs_arr):
......@@ -256,7 +257,7 @@ class ImmutableGraphIndex(object):
res = mx.nd.contrib.dgl_subgraph(self._in_csr, *vs_arr, return_mapping=True)
in_csrs = res[0:len(vs_arr)]
induced_nodes = vs_arr
induced_edges = [e.data for e in res[len(vs_arr):]]
induced_edges = [lambda: e.data for e in res[len(vs_arr):]]
assert len(in_csrs) == len(induced_nodes)
assert len(in_csrs) == len(induced_edges)
gis = []
......@@ -279,11 +280,11 @@ class ImmutableGraphIndex(object):
raise NotImplementedError
num_nodes = []
num_subgs = len(seed_ids)
res = mx.nd.contrib.neighbor_sample(g, *seed_ids, num_hops=num_hops,
num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)
res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops,
num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)
vertices, subgraphs = res[0:num_subgs], res[num_subgs:len(res)]
vertices, subgraphs = res[0:num_subgs], res[num_subgs:(2*num_subgs)]
num_nodes = [subg_v[-1].asnumpy()[0] for subg_v in vertices]
inputs = []
......@@ -298,7 +299,7 @@ class ImmutableGraphIndex(object):
elif neighbor_type == 'out':
gis = [ImmutableGraphIndex(None, csr) for csr in compacts]
parent_nodes = [v[0:size] for v, size in zip(vertices, num_nodes)]
parent_edges = [e.data for e in subgraphs]
parent_edges = [lambda: e.data for e in subgraphs]
return gis, parent_nodes, parent_edges
def adjacency_matrix(self, transpose, ctx):
......@@ -326,11 +327,7 @@ class ImmutableGraphIndex(object):
mat = self._out_csr
else:
mat = self._in_csr
indices = mat.indices
indptr = mat.indptr
data = mx.nd.ones(indices.shape, dtype=np.float32, ctx=ctx)
return mx.nd.sparse.csr_matrix((data, indices, indptr), shape=mat.shape)
return mx.nd.contrib.dgl_adjacency(mat.as_in_context(ctx))
def from_coo_matrix(self, out_coo):
"""construct the graph index from a SciPy coo matrix.
......
......@@ -12,6 +12,8 @@ class NSSubgraphLoader(object):
neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None):
self._g = g
if not g._graph.is_readonly():
raise NotImplementedError("subgraph loader only support read-only graphs.")
self._batch_size = batch_size
self._expand_factor = expand_factor
self._num_hops = num_hops
......@@ -51,7 +53,7 @@ class NSSubgraphLoader(object):
self._num_hops, self._neighbor_type,
self._node_prob, self._max_subgraph_size)
subgraphs = [DGLSubGraph(self._g, i.induced_nodes, i.induced_edges, \
i, readonly=self._g._readonly) for i in sgi]
i) for i in sgi]
self._subgraphs.extend(subgraphs)
self._seed_ids.extend(seed_ids)
......
......@@ -575,7 +575,7 @@ class DGLGraph(object):
"""
return self._graph.in_degree(v)
def in_degrees(self, v):
def in_degrees(self, v=ALL):
"""Return the in degrees of the nodes.
Parameters
......@@ -592,7 +592,8 @@ class DGLGraph(object):
--------
in_degree
"""
v = utils.toindex(v)
if not is_all(v):
v = utils.toindex(v)
return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v):
......@@ -614,7 +615,7 @@ class DGLGraph(object):
"""
return self._graph.out_degree(v)
def out_degrees(self, v):
def out_degrees(self, v=ALL):
"""Return the out degrees of the nodes.
Parameters
......@@ -631,7 +632,8 @@ class DGLGraph(object):
--------
out_degree
"""
v = utils.toindex(v)
if not is_all(v):
v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None):
......@@ -1396,8 +1398,7 @@ class DGLGraph(object):
"""
induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges,
sgi, readonly=self._readonly)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def subgraphs(self, nodes):
"""Generate the subgraphs among the given nodes.
......@@ -1419,7 +1420,7 @@ class DGLGraph(object):
induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes)
return [dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges,
sgi, readonly=self._readonly) for sgi in sgis]
sgi) for sgi in sgis]
def edge_subgraph(self, edges):
"""Generate the subgraph among the given edges.
......
......@@ -7,7 +7,7 @@ import scipy
from ._ffi.base import c_array
from ._ffi.function import _init_api
from .base import DGLError
from .base import DGLError, is_all
from . import backend as F
from . import utils
from .immutable_graph_index import create_immutable_graph_index
......@@ -84,6 +84,16 @@ class GraphIndex(object):
"""
return bool(_CAPI_DGLGraphIsMultigraph(self._handle))
def is_readonly(self):
    """Tell whether this graph index is read-only.

    Returns
    -------
    bool
        Always ``False``: this (mutable) graph index is never read-only.
    """
    return False
def number_of_nodes(self):
"""Return the number of nodes.
......@@ -385,6 +395,9 @@ class GraphIndex(object):
int
The in degree array.
"""
if is_all(v):
v = np.arange(0, self.number_of_nodes(), dtype=np.int64)
v = utils.toindex(v)
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array))
......@@ -416,6 +429,9 @@ class GraphIndex(object):
int
The out degree array.
"""
if is_all(v):
v = np.arange(0, self.number_of_nodes(), dtype=np.int64)
v = utils.toindex(v)
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array))
......
......@@ -8,6 +8,7 @@ import scipy.sparse as sp
from ._ffi.function import _init_api
from . import backend as F
from . import utils
from .base import ALL, is_all
class ImmutableGraphIndex(object):
"""Graph index object on immutable graphs.
......@@ -73,6 +74,16 @@ class ImmutableGraphIndex(object):
# Immutable graph doesn't support multi-edge.
return False
def is_readonly(self):
    """Tell whether this graph index is read-only.

    Returns
    -------
    bool
        Always ``True``: an immutable graph index cannot be mutated.
    """
    return True
def number_of_nodes(self):
"""Return the number of nodes.
......@@ -353,9 +364,12 @@ class ImmutableGraphIndex(object):
int
The in degree array.
"""
v_array = v.tousertensor()
deg = self._get_in_degree()
return utils.toindex(F.gather_row(deg, v_array))
if is_all(v):
return utils.toindex(deg)
else:
v_array = v.tousertensor()
return utils.toindex(F.gather_row(deg, v_array))
def out_degree(self, v):
"""Return the out degree of the node.
......@@ -386,9 +400,12 @@ class ImmutableGraphIndex(object):
int
The out degree array.
"""
v_array = v.tousertensor()
deg = self._get_out_degree()
return utils.toindex(F.gather_row(deg, v_array))
if is_all(v):
return utils.toindex(deg)
else:
v_array = v.tousertensor()
return utils.toindex(F.gather_row(deg, v_array))
def node_subgraph(self, v):
"""Return the induced node subgraph.
......@@ -405,9 +422,7 @@ class ImmutableGraphIndex(object):
"""
v = v.tousertensor()
gi, induced_n, induced_e = self._sparse.node_subgraph(v)
induced_nodes = utils.toindex(induced_n)
induced_edges = utils.toindex(induced_e)
return ImmutableSubgraphIndex(gi, self, induced_nodes, induced_edges)
return ImmutableSubgraphIndex(gi, self, induced_n, induced_e)
def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs.
......@@ -424,8 +439,6 @@ class ImmutableGraphIndex(object):
"""
vs_arr = [v.tousertensor() for v in vs_arr]
gis, induced_nodes, induced_edges = self._sparse.node_subgraphs(vs_arr)
induced_nodes = [utils.toindex(v) for v in induced_nodes]
induced_edges = [utils.toindex(e) for e in induced_edges]
return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
......@@ -439,7 +452,6 @@ class ImmutableGraphIndex(object):
node_prob,
max_subgraph_size)
induced_nodes = [utils.toindex(v) for v in induced_nodes]
induced_edges = [utils.toindex(e) for e in induced_edges]
return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
......@@ -543,6 +555,19 @@ class ImmutableGraphIndex(object):
raise Exception('immutable graph doesn\'t support line_graph')
class ImmutableSubgraphIndex(ImmutableGraphIndex):
"""Graph index for an immutable subgraph.
Parameters
----------
backend_sparse : a sparse matrix from the backend framework.
The sparse matrix that represents a subgraph.
parent : GraphIndex
The parent graph index.
induced_nodes : tensor
The parent node ids in this subgraph.
induced_edges : a lambda function that returns a tensor
The parent edge ids in this subgraph.
"""
def __init__(self, backend_sparse, parent, induced_nodes, induced_edges):
super(ImmutableSubgraphIndex, self).__init__(backend_sparse)
......@@ -552,11 +577,25 @@ class ImmutableSubgraphIndex(ImmutableGraphIndex):
@property
def induced_edges(self):
return self._induced_edges
"""Return parent edge ids.
Returns
-------
A lambda function that returns utils.Index
The parent edge ids.
"""
return lambda: utils.toindex(self._induced_edges())
@property
def induced_nodes(self):
return self._induced_nodes
"""Return parent node ids.
Returns
-------
utils.Index
The parent node ids.
"""
return utils.toindex(self._induced_nodes)
def create_immutable_graph_index(graph_data=None):
"""Create a graph index object.
......
......@@ -43,11 +43,10 @@ class DGLSubGraph(DGLGraph):
The graph index.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
readonly : bool, optional
Whether the graph structure is read-only (default: False).
"""
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False, readonly=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx, readonly=readonly)
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx,
readonly=graph_idx.is_readonly())
self._parent = parent
self._parent_nid = parent_nid
self._parent_eid = parent_eid
......@@ -79,6 +78,15 @@ class DGLSubGraph(DGLGraph):
"""
return self._parent_nid.tousertensor()
def _get_parent_eid(self):
    """Return the parent edge ids as a ``utils.Index``.

    The stored value may be lazily evaluated: it is either a
    ``utils.Index`` already, or a zero-argument callable that
    produces one when invoked.
    """
    eid = self._parent_eid
    if isinstance(eid, utils.Index):
        return eid
    return eid()
@property
def parent_eid(self):
"""Get the parent edge ids.
......@@ -91,7 +99,7 @@ class DGLSubGraph(DGLGraph):
Tensor
The parent edge id array.
"""
return self._parent_eid.tousertensor()
return self._get_parent_eid().tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
......@@ -103,8 +111,9 @@ class DGLSubGraph(DGLGraph):
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
self._parent._edge_frame.update_rows(
self._parent_eid, self._edge_frame, inplace=inplace)
if self._parent._edge_frame.num_rows != 0:
self._parent._edge_frame.update_rows(
self._get_parent_eid(), self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
......@@ -116,7 +125,7 @@ class DGLSubGraph(DGLGraph):
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._parent_eid]))
self._parent._edge_frame[self._get_parent_eid()]))
def map_to_subgraph_nid(self, parent_vids):
return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment