"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a44d55d87ba3628ac79292fdcaead7fb98fc130b"
Unverified commit b2b8be25, authored by Da Zheng, committed by GitHub

[API] update graph store API. (#549)

* add init_ndata and init_edata in DGLGraph.

* adjust SharedMemoryGraph API.

* print warning.

* fix comment.

* update example

* fix.

* fix examples.

* add unit tests.

* add comments.
parent cdfca992
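
To put the API change in context before the diff: a minimal sketch (not part of the commit) of the updated worker-side calls, using a small in-memory DGLGraph with the MXNet backend; a shared-memory graph-store client exposes the same calls, and the graph, feature sizes, and node IDs below are illustrative only.

    import mxnet as mx
    import dgl
    import dgl.function as fn

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

    # Write node data with set_n_repr instead of assigning to g.ndata,
    # which the graph store no longer exposes.
    g.set_n_repr({'features': mx.nd.ones((4, 8)),
                  'labels': mx.nd.array([0, 1, 0, 1])})

    # Pre-allocate a named node embedding by shape and dtype (new init_ndata API).
    g.init_ndata('h_0', (g.number_of_nodes(), 16), 'float32')

    # update_all (which replaces dist_update_all on the graph store) aggregates
    # neighbor features into a 'preprocess' field.
    g.update_all(fn.copy_src(src='features', out='m'),
                 fn.sum(msg='m', out='preprocess'))

    # Read data for a subset of nodes instead of materializing all node data.
    batch_labels = g.nodes[[0, 2]].data['labels']

The revised examples and unit tests in the diff below follow this pattern.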
@@ -140,36 +140,26 @@ class GCNInfer(gluon.Block):
 def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
-    features = g.ndata['features']
-    labels = g.ndata['labels']
-    in_feats = features.shape[1]
-    g_ctx = features.context
+    n0_feats = g.nodes[0].data['features']
+    num_nodes = g.number_of_nodes()
+    in_feats = n0_feats.shape[1]
+    g_ctx = n0_feats.context
     norm = mx.nd.expand_dims(1./g.in_degrees().astype('float32'), 1)
-    g.ndata['norm'] = norm.as_in_context(g_ctx)
+    g.set_n_repr({'norm': norm.as_in_context(g_ctx)})
     degs = g.in_degrees().astype('float32').asnumpy()
     degs[degs > args.num_neighbors] = args.num_neighbors
-    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)
+    g.set_n_repr({'subg_norm': mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)})
     n_layers = args.n_layers
-    if distributed:
-        g.dist_update_all(fn.copy_src(src='features', out='m'),
-                          fn.sum(msg='m', out='preprocess'),
-                          lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
-        for i in range(n_layers - 1):
-            g.init_ndata('h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
-            g.init_ndata('agg_h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
-        g.init_ndata('h_{}'.format(n_layers-1), (features.shape[0], 2*args.n_hidden), 'float32')
-        g.init_ndata('agg_h_{}'.format(n_layers-1), (features.shape[0], 2*args.n_hidden), 'float32')
-    else:
-        g.update_all(fn.copy_src(src='features', out='m'),
-                     fn.sum(msg='m', out='preprocess'),
-                     lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
-        for i in range(n_layers):
-            g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
-            g.ndata['agg_h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
-        g.ndata['h_{}'.format(n_layers-1)] = mx.nd.zeros((features.shape[0], 2*args.n_hidden), ctx=g_ctx)
-        g.ndata['agg_h_{}'.format(n_layers-1)] = mx.nd.zeros((features.shape[0], 2*args.n_hidden), ctx=g_ctx)
+    g.update_all(fn.copy_src(src='features', out='m'),
+                 fn.sum(msg='m', out='preprocess'),
+                 lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+    for i in range(n_layers - 1):
+        g.init_ndata('h_{}'.format(i), (num_nodes, args.n_hidden), 'float32')
+        g.init_ndata('agg_h_{}'.format(i), (num_nodes, args.n_hidden), 'float32')
+    g.init_ndata('h_{}'.format(n_layers-1), (num_nodes, 2*args.n_hidden), 'float32')
+    g.init_ndata('agg_h_{}'.format(n_layers-1), (num_nodes, 2*args.n_hidden), 'float32')
 
     model = GCNSampling(in_feats,
                         args.n_hidden,
@@ -220,8 +210,8 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
             dests = nf.layer_parent_nid(i+1).as_in_context(g_ctx)
             # TODO we could use DGLGraph.pull to implement this, but the current
             # implementation of pull is very slow. Let's manually do it for now.
-            g.ndata[agg_history_str][dests] = mx.nd.dot(mx.nd.take(adj, dests),
-                                                        g.ndata['h_{}'.format(i)])
+            agg = mx.nd.dot(mx.nd.take(adj, dests), g.nodes[:].data['h_{}'.format(i)])
+            g.set_n_repr({agg_history_str: agg}, dests)
         node_embed_names = [['preprocess', 'h_0']]
         for i in range(1, n_layers):
@@ -233,7 +223,7 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
             with mx.autograd.record():
                 pred = model(nf)
                 batch_nids = nf.layer_parent_nid(-1)
-                batch_labels = labels[batch_nids].as_in_context(ctx)
+                batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
                 loss = loss_fcn(pred, batch_labels)
                 loss = loss.sum() / len(batch_nids)
@@ -269,7 +259,7 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
             nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)
             pred = infer_model(nf)
             batch_nids = nf.layer_parent_nid(-1)
-            batch_labels = labels[batch_nids].as_in_context(ctx)
+            batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
             num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
             num_tests += nf.layer_size(-1)
             if distributed:
......
@@ -110,13 +110,13 @@ class GCNInfer(gluon.Block):
 def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
-    in_feats = g.ndata['features'].shape[1]
-    labels = g.ndata['labels']
-    g_ctx = labels.context
+    n0_feats = g.nodes[0].data['features']
+    in_feats = n0_feats.shape[1]
+    g_ctx = n0_feats.context
     degs = g.in_degrees().astype('float32').as_in_context(g_ctx)
     norm = mx.nd.expand_dims(1./degs, 1)
-    g.ndata['norm'] = norm
+    g.set_n_repr({'norm': norm})
 
     model = GCNSampling(in_feats,
                         args.n_hidden,
@@ -159,7 +159,7 @@ def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
             with mx.autograd.record():
                 pred = model(nf)
                 batch_nids = nf.layer_parent_nid(-1)
-                batch_labels = labels[batch_nids].as_in_context(ctx)
+                batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
                 loss = loss_fcn(pred, batch_labels)
                 loss = loss.sum() / len(batch_nids)
@@ -183,7 +183,7 @@ def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
             nf.copy_from_parent(ctx=ctx)
             pred = infer_model(nf)
             batch_nids = nf.layer_parent_nid(-1)
-            batch_labels = labels[batch_nids].as_in_context(ctx)
+            batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
             num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
             num_tests += nf.layer_size(-1)
             break
......
@@ -179,32 +179,24 @@ class GraphSAGEInfer(gluon.Block):
 def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
-    features = g.ndata['features']
-    labels = g.ndata['labels']
-    in_feats = g.ndata['features'].shape[1]
-    g_ctx = features.context
+    n0_feats = g.nodes[0].data['features']
+    num_nodes = g.number_of_nodes()
+    in_feats = n0_feats.shape[1]
+    g_ctx = n0_feats.context
     norm = mx.nd.expand_dims(1./g.in_degrees().astype('float32'), 1)
-    g.ndata['norm'] = norm.as_in_context(g_ctx)
+    g.set_n_repr({'norm': norm.as_in_context(g_ctx)})
     degs = g.in_degrees().astype('float32').asnumpy()
     degs[degs > args.num_neighbors] = args.num_neighbors
-    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)
+    g.set_n_repr({'subg_norm': mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)})
     n_layers = args.n_layers
-    if distributed:
-        g.dist_update_all(fn.copy_src(src='features', out='m'),
-                          fn.sum(msg='m', out='preprocess'),
-                          lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
-        for i in range(n_layers):
-            g.init_ndata('h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
-            g.init_ndata('agg_h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
-    else:
-        g.update_all(fn.copy_src(src='features', out='m'),
-                     fn.sum(msg='m', out='preprocess'),
-                     lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
-        for i in range(n_layers):
-            g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
-            g.ndata['agg_h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
+    g.update_all(fn.copy_src(src='features', out='m'),
+                 fn.sum(msg='m', out='preprocess'),
+                 lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+    for i in range(n_layers):
+        g.init_ndata('h_{}'.format(i), (num_nodes, args.n_hidden), 'float32')
+        g.init_ndata('agg_h_{}'.format(i), (num_nodes, args.n_hidden), 'float32')
 
     model = GraphSAGETrain(in_feats,
                            args.n_hidden,
@@ -255,8 +247,8 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
             dests = nf.layer_parent_nid(i+1).as_in_context(g_ctx)
             # TODO we could use DGLGraph.pull to implement this, but the current
             # implementation of pull is very slow. Let's manually do it for now.
-            g.ndata[agg_history_str][dests] = mx.nd.dot(mx.nd.take(adj, dests),
-                                                        g.ndata['h_{}'.format(i)])
+            agg = mx.nd.dot(mx.nd.take(adj, dests), g.nodes[:].data['h_{}'.format(i)])
+            g.set_n_repr({agg_history_str: agg}, dests)
         node_embed_names = [['preprocess', 'features', 'h_0']]
         for i in range(1, n_layers):
@@ -268,7 +260,7 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
             with mx.autograd.record():
                 pred = model(nf)
                 batch_nids = nf.layer_parent_nid(-1)
-                batch_labels = labels[batch_nids].as_in_context(ctx)
+                batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
                 loss = loss_fcn(pred, batch_labels)
                 if distributed:
                     loss = loss.sum() / (len(batch_nids) * g.num_workers)
@@ -308,7 +300,7 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
             pred = infer_model(nf)
             batch_nids = nf.layer_parent_nid(-1)
-            batch_labels = labels[batch_nids].as_in_context(ctx)
+            batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
             num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
             num_tests += nf.layer_size(-1)
             if distributed:
......
@@ -8,7 +8,7 @@ from functools import partial
 from collections.abc import MutableMapping
-from ..base import ALL, is_all, DGLError
+from ..base import ALL, is_all, DGLError, dgl_warning
 from .. import backend as F
 from ..graph import DGLGraph
 from .. import utils
@@ -304,6 +304,7 @@ class SharedMemoryStoreServer(object):
         The port that the server listens to.
     """
     def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
+        self.server = None
         graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
         indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
         graph_idx.from_csr_matrix(indptr, indices, edge_dir, _get_graph_path(graph_name))
@@ -400,8 +401,9 @@ class SharedMemoryStoreServer(object):
         self.server.register_function(all_enter, "all_enter")
 
     def __del__(self):
+        if self.server is not None:
+            self.server.server_close()
         self._graph = None
-        self.server.server_close()
 
     @property
     def ndata(self):
@@ -436,7 +438,78 @@ class SharedMemoryStoreServer(object):
             self.server.handle_request()
         self._graph = None
 
-class SharedMemoryDGLGraph(DGLGraph):
+class BaseGraphStore(DGLGraph):
+    """The base class of the graph store.
+
+    The shared-memory graph store and the distributed graph store will inherit from
+    this base class. The graph stores only support large read-only graphs. Thus, many of
+    the DGLGraph APIs aren't supported.
+
+    Specifically, the graph store doesn't support the following methods:
+
+    - ndata
+    - edata
+    - incidence_matrix
+    - line_graph
+    - reverse
+    """
+    def __init__(self,
+                 graph_data=None,
+                 multigraph=False):
+        super(BaseGraphStore, self).__init__(graph_data, multigraph=multigraph, readonly=True)
+
+    @property
+    def ndata(self):
+        """Return the data view of all the nodes.
+
+        DGLGraph.ndata is an abbreviation of DGLGraph.nodes[:].data
+        """
+        raise Exception("Graph store doesn't support accessing data of all nodes.")
+
+    @property
+    def edata(self):
+        """Return the data view of all the edges.
+
+        DGLGraph.edata is an abbreviation of DGLGraph.edges[:].data
+
+        See Also
+        --------
+        dgl.DGLGraph.edges
+        """
+        raise Exception("Graph store doesn't support accessing data of all edges.")
+
+    def incidence_matrix(self, typestr, ctx=F.cpu()):
+        """Return the incidence matrix representation of this graph.
+
+        Parameters
+        ----------
+        typestr : str
+            Can be either ``in``, ``out`` or ``both``
+        ctx : context, optional (default=cpu)
+            The context of returned incidence matrix.
+
+        Returns
+        -------
+        SparseTensor
+            The incidence matrix.
+        """
+        raise Exception("Graph store doesn't support creating an incidence matrix.")
+
+    def line_graph(self, backtracking=True, shared=False):
+        """Return the line graph of this graph.
+
+        See :func:`~dgl.transform.line_graph`.
+        """
+        raise Exception("Graph store doesn't support creating a line graph.")
+
+    def reverse(self, share_ndata=False, share_edata=False):
+        """Return the reverse of this graph.
+
+        See :func:`~dgl.transform.reverse`.
+        """
+        raise Exception("Graph store doesn't support reversing a graph.")
+
+
+class SharedMemoryDGLGraph(BaseGraphStore):
     """Shared-memory DGLGraph.
 
     This is a client to access data in the shared-memory graph store that has loads
@@ -461,7 +534,7 @@ class SharedMemoryDGLGraph(DGLGraph):
         graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
         graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir)
-        super(SharedMemoryDGLGraph, self).__init__(graph_idx, multigraph=multigraph, readonly=True)
+        super(SharedMemoryDGLGraph, self).__init__(graph_idx, multigraph=multigraph)
         self._init_manager = InitializerManager()
 
         # map all ndata and edata from the server.
@@ -506,13 +579,13 @@ class SharedMemoryDGLGraph(DGLGraph):
         assert self.number_of_nodes() == shape[0]
         data = empty_shared_mem(_get_ndata_path(self._graph_name, ndata_name), False, shape, dtype)
         dlpack = data.to_dlpack()
-        self.ndata[ndata_name] = F.zerocopy_from_dlpack(dlpack)
+        self.set_n_repr({ndata_name: F.zerocopy_from_dlpack(dlpack)})
 
     def _init_edata(self, edata_name, shape, dtype):
         assert self.number_of_edges() == shape[0]
         data = empty_shared_mem(_get_edata_path(self._graph_name, edata_name), False, shape, dtype)
         dlpack = data.to_dlpack()
-        self.edata[edata_name] = F.zerocopy_from_dlpack(dlpack)
+        self.set_e_repr({edata_name: F.zerocopy_from_dlpack(dlpack)})
 
     @property
     def num_workers(self):
@@ -544,7 +617,7 @@ class SharedMemoryDGLGraph(DGLGraph):
                 continue
             self.proxy.leave_barrier(self._worker_id, bid)
 
-    def init_ndata(self, ndata_name, shape, dtype):
+    def init_ndata(self, ndata_name, shape, dtype, ctx=F.cpu()):
         """Create node embedding.
 
         It first creates the node embedding in the server and maps it to the current process
@@ -559,7 +632,11 @@ class SharedMemoryDGLGraph(DGLGraph):
         dtype : string
             The data type of the node embedding. The currently supported data types
             are "float32" and "int32".
+        ctx : DGLContext
+            The column context.
         """
+        if ctx != F.cpu():
+            raise Exception("graph store only supports CPU context for node data")
         init = self._node_frame.get_initializer(ndata_name)
         if init is None:
             self._node_frame._frame._warn_and_set_initializer()
@@ -568,7 +645,7 @@ class SharedMemoryDGLGraph(DGLGraph):
         self.proxy.init_ndata(init, ndata_name, shape, dtype)
         self._init_ndata(ndata_name, shape, dtype)
 
-    def init_edata(self, edata_name, shape, dtype):
+    def init_edata(self, edata_name, shape, dtype, ctx=F.cpu()):
         """Create edge embedding.
 
         It first creates the edge embedding in the server and maps it to the current process
@@ -583,7 +660,11 @@ class SharedMemoryDGLGraph(DGLGraph):
         dtype : string
             The data type of the edge embedding. The currently supported data types
             are "float32" and "int32".
+        ctx : DGLContext
+            The column context.
         """
+        if ctx != F.cpu():
+            raise Exception("graph store only supports CPU context for edge data")
         init = self._edge_frame.get_initializer(edata_name)
         if init is None:
             self._edge_frame._frame._warn_and_set_initializer()
@@ -592,13 +673,56 @@ class SharedMemoryDGLGraph(DGLGraph):
         self.proxy.init_edata(init, edata_name, shape, dtype)
         self._init_edata(edata_name, shape, dtype)
 
+    def get_n_repr(self, u=ALL):
+        """Get node(s) representation.
+
+        The returned feature tensor batches multiple node features on the first dimension.
+
+        Parameters
+        ----------
+        u : node, container or tensor
+            The node(s).
+
+        Returns
+        -------
+        dict
+            Representation dict from feature name to feature tensor.
+        """
+        if len(self.node_attr_schemes()) == 0:
+            return dict()
+        if is_all(u):
+            dgl_warning("It may not be safe to access node data of all nodes. "
+                        "It's recommended to access node data of a subset of nodes directly.")
+            return dict(self._node_frame)
+        else:
+            u = utils.toindex(u)
+            return self._node_frame.select_rows(u)
+
+    def get_e_repr(self, edges=ALL):
+        """Get edge(s) representation.
+
+        Parameters
+        ----------
+        edges : edges
+            Edges can be a pair of endpoint nodes (u, v), or a
+            tensor of edge ids. The default value is all the edges.
+
+        Returns
+        -------
+        dict
+            Representation dict
+        """
+        if is_all(edges):
+            dgl_warning("It may not be safe to access edge data of all edges. "
+                        "It's recommended to access edge data of a subset of edges directly.")
+        return super(SharedMemoryDGLGraph, self).get_e_repr(edges)
+
-    def dist_update_all(self, message_func="default",
+    def update_all(self, message_func="default",
                    reduce_func="default",
                    apply_node_func="default"):
        """ Distribute the computation in update_all among all pre-defined workers.
 
-        dist_update_all requires that all workers invoke this method and will
+        update_all requires that all workers invoke this method and will
         return only when all workers finish their own portion of computation.
         The number of workers are pre-defined. If one of them doesn't invoke the method,
         it won't return because some portion of computation isn't finished.
......
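
To make the barrier semantics described in the update_all docstring above concrete, here is a worker-side sketch based on the unit test in this commit; it assumes a shared-memory store server is already serving a graph registered as 'test_graph1' and that this process is one of the pre-defined workers (the graph name, field names, and port are illustrative).

    import dgl
    import dgl.function as fn

    # Attach to the shared-memory graph store as a worker client.
    g = dgl.contrib.graph_store.create_graph_from_store('test_graph1', 'shared_mem', port=8000)
    g._sync_barrier()

    # Every pre-defined worker must issue the same update_all call; it returns
    # only after all workers have finished their share of the computation.
    g.update_all(fn.copy_src(src='feat', out='m'),
                 fn.sum(msg='m', out='preprocess'))

    # Read the result for a subset of nodes rather than all node data.
    preprocess = g.nodes[[0, 1, 2]].data['preprocess']

    g._sync_barrier()
    g.destroy()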
@@ -8,7 +8,7 @@ import dgl
 from .base import ALL, is_all, DGLError
 from . import backend as F
 from . import init
-from .frame import FrameRef, Frame
+from .frame import FrameRef, Frame, Scheme
 from .graph_index import create_graph_index
 from .runtime import ir, scheduler, Runtime
 from . import utils
@@ -1556,6 +1556,50 @@ class DGLGraph(DGLBaseGraph):
         """
         return self.edges[:].data
 
+    def init_ndata(self, ndata_name, shape, dtype, ctx=F.cpu()):
+        """Create node embedding.
+
+        It creates a new column for the node embedding in the node frame, with
+        the given shape and data type.
+
+        Parameters
+        ----------
+        ndata_name : string
+            The name of node embedding
+        shape : tuple
+            The shape of the node embedding
+        dtype : string
+            The data type of the node embedding. The currently supported data types
+            are "float32" and "int32".
+        ctx : DGLContext
+            The column context.
+        """
+        scheme = Scheme(tuple(shape[1:]), F.data_type_dict[dtype])
+        self._node_frame._frame.add_column(ndata_name, scheme, ctx)
+
+    def init_edata(self, edata_name, shape, dtype, ctx=F.cpu()):
+        """Create edge embedding.
+
+        It creates a new column for the edge embedding in the edge frame, with
+        the given shape and data type.
+
+        Parameters
+        ----------
+        edata_name : string
+            The name of edge embedding
+        shape : tuple
+            The shape of the edge embedding
+        dtype : string
+            The data type of the edge embedding. The currently supported data types
+            are "float32" and "int32".
+        ctx : DGLContext
+            The column context.
+        """
+        scheme = Scheme(tuple(shape[1:]), F.data_type_dict[dtype])
+        self._edge_frame._frame.add_column(edata_name, scheme, ctx)
+
     def set_n_repr(self, data, u=ALL, inplace=False):
         """Set node(s) representation.
@@ -1692,7 +1736,7 @@ class DGLGraph(DGLBaseGraph):
         self._edge_frame.update_rows(eid, data, inplace=inplace)
 
     def get_e_repr(self, edges=ALL):
-        """Get node(s) representation.
+        """Get edge(s) representation.
 
         Parameters
         ----------
......
@@ -26,6 +26,10 @@ def test_graph_creation():
     g.add_nodes(5)
     g.ndata['h'] = 3 * F.ones((5, 2))
     assert F.allclose(3 * F.ones((5, 2)), g.ndata['h'])
+    g.init_ndata('h1', (g.number_of_nodes(), 3), 'float32')
+    assert F.allclose(F.zeros((g.number_of_nodes(), 3)), g.ndata['h1'])
+    g.init_edata('h2', (g.number_of_edges(), 3), 'float32')
+    assert F.allclose(F.zeros((g.number_of_edges(), 3)), g.edata['h2'])
 
 def test_create_from_elist():
     elist = [(2, 1), (1, 0), (2, 0), (3, 0), (0, 2)]
......
@@ -37,12 +37,12 @@ def check_init_func(worker_id, graph_name):
     coo = csr.tocoo()
     assert F.array_equal(dst, F.tensor(coo.row))
     assert F.array_equal(src, F.tensor(coo.col))
-    assert F.array_equal(g.ndata['feat'][0], F.tensor(np.arange(10), dtype=np.float32))
-    assert F.array_equal(g.edata['feat'][0], F.tensor(np.arange(10), dtype=np.float32))
+    assert F.array_equal(g.nodes[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
+    assert F.array_equal(g.edges[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
 
     g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
     g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
     g._sync_barrier()
-    check_array_shared_memory(g, worker_id, [g.ndata['test4'], g.edata['test4']])
+    check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])
     g.destroy()
 
 def server_func(num_workers, graph_name):
@@ -58,7 +58,7 @@ def server_func(num_workers, graph_name):
     g.edata['feat'] = mx.nd.arange(num_edges * 10).reshape((num_edges, 10))
     g.run()
 
-def test_test_init():
+def test_init():
     serv_p = Process(target=server_func, args=(2, 'test_graph1'))
     work_p1 = Process(target=check_init_func, args=(0, 'test_graph1'))
     work_p2 = Process(target=check_init_func, args=(1, 'test_graph1'))
@@ -75,13 +75,12 @@ def check_update_all_func(worker_id, graph_name):
     print("worker starts")
     g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem", port=rand_port)
     g._sync_barrier()
-    g.dist_update_all(fn.copy_src(src='feat', out='m'),
-                      fn.sum(msg='m', out='preprocess'))
+    g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
     adj = g.adjacency_matrix()
-    tmp = mx.nd.dot(adj, g.ndata['feat'])
-    assert np.all((g.ndata['preprocess'] == tmp).asnumpy())
+    tmp = mx.nd.dot(adj, g.nodes[:].data['feat'])
+    assert np.all((g.nodes[:].data['preprocess'] == tmp).asnumpy())
     g._sync_barrier()
-    check_array_shared_memory(g, worker_id, [g.ndata['preprocess']])
+    check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])
     g.destroy()
 
 def test_update_all():
@@ -96,5 +95,5 @@ def test_update_all():
     work_p2.join()
 
 if __name__ == '__main__':
-    test_test_init()
+    test_init()
     test_update_all()