Unverified commit cdfca992, authored by Da Zheng and committed by GitHub

[BUGFix] Improve multi-processing training (#526)

* fix.

* add comment.

* remove.

* temp fix.

* initialize for shared memory.

* fix graphsage.

* fix gcn.

* add more unit tests.

* add more tests.

* avoid creating shared-memory exclusively.

* redefine remote initializer.

* improve initializer.

* fix unit test.

* fix lint.

* fix lint.

* initialize data in the graph store server properly.

* fix test.

* fix test.

* fix test.

* small fix.

* add comments.

* cleanup server.

* test graph store with a random port.

* print.

* print to stderr.

* test1

* test2

* remove comment.

* adjust the initializer signature.
parent 4b761571
@@ -203,6 +203,11 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, d
     dur = []
     adj = g.adjacency_matrix().as_in_context(g_ctx)
     for epoch in range(args.n_epochs):
+        start = time.time()
+        if distributed:
+            msg_head = "Worker {:d}, epoch {:d}".format(g.worker_id, epoch)
+        else:
+            msg_head = "epoch {:d}".format(epoch)
         for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                        args.num_neighbors,
                                                        neighbor_type='in',
@@ -239,6 +244,8 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, d
                 node_embed_names.append([])

             nf.copy_to_parent(node_embed_names=node_embed_names)
+        mx.nd.waitall()
+        print(msg_head + ': training takes ' + str(time.time() - start))

         infer_params = infer_model.collect_params()
@@ -249,21 +256,25 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, d
         num_acc = 0.
         num_tests = 0
-        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
-                                                        g.number_of_nodes(),
-                                                        neighbor_type='in',
-                                                        num_hops=n_layers,
-                                                        seed_nodes=test_nid):
-            node_embed_names = [['preprocess']]
-            for i in range(n_layers):
-                node_embed_names.append(['norm'])
-
-            nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)
-            pred = infer_model(nf)
-            batch_nids = nf.layer_parent_nid(-1)
-            batch_labels = labels[batch_nids].as_in_context(ctx)
-            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
-            num_tests += nf.layer_size(-1)
-            break
-
-        print("Test Accuracy {:.4f}". format(num_acc/num_tests))
+        if not distributed or g.worker_id == 0:
+            for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
+                                                            g.number_of_nodes(),
+                                                            neighbor_type='in',
+                                                            num_hops=n_layers,
+                                                            seed_nodes=test_nid):
+                node_embed_names = [['preprocess']]
+                for i in range(n_layers):
+                    node_embed_names.append(['norm'])
+
+                nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)
+                pred = infer_model(nf)
+                batch_nids = nf.layer_parent_nid(-1)
+                batch_labels = labels[batch_nids].as_in_context(ctx)
+                num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
+                num_tests += nf.layer_size(-1)
+                if distributed:
+                    g._sync_barrier()
+                print("Test Accuracy {:.4f}". format(num_acc/num_tests))
+                break
+        elif distributed:
+            g._sync_barrier()
@@ -282,6 +282,7 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samp
                 node_embed_names.append([])

             nf.copy_to_parent(node_embed_names=node_embed_names)
+        mx.nd.waitall()
         print(msg_head + ': training takes ' + str(time.time() - start))

         infer_params = infer_model.collect_params()
@@ -294,7 +295,6 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samp
         num_tests = 0
         if not distributed or g.worker_id == 0:
-            start = time.time()
             for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                             g.number_of_nodes(),
                                                             neighbor_type='in',
......
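Note on the example-script changes above: training time is now reported per epoch and per worker, and the accuracy loop runs only on worker 0 while the remaining workers wait at g._sync_barrier(), so all processes stay in lock-step across epochs. Below is a minimal, standalone sketch of that pattern using only the Python standard library; the process count, function names, and the three-epoch loop are illustrative and are not DGL code.

from multiprocessing import Barrier, Process

N_WORKERS = 2

def evaluate(worker_id):
    # Stand-in for the test loop that only worker 0 runs in the diff above.
    print('worker %d: evaluating' % worker_id)

def run_worker(worker_id, barrier):
    for epoch in range(3):
        # ... every worker trains on its own mini-batches for this epoch ...
        if worker_id == 0:
            evaluate(worker_id)
        # All workers meet here before the next epoch, mirroring g._sync_barrier().
        barrier.wait()

if __name__ == '__main__':
    barrier = Barrier(N_WORKERS)
    procs = [Process(target=run_worker, args=(i, barrier)) for i in range(N_WORKERS)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()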
@@ -14,6 +14,8 @@ from graphsage_cv import graphsage_cv_train
 def main(args):
     g = dgl.contrib.graph_store.create_graph_from_store(args.graph_name, "shared_mem")
+    # We need to set random seed here. Otherwise, all processes have the same mini-batches.
+    mx.random.seed(g.worker_id)
     features = g.ndata['features']
     labels = g.ndata['labels']
     train_mask = g.ndata['train_mask']
......
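The launcher now seeds MXNet's RNG with the worker id: every worker otherwise starts from the same default seed, so the neighbor samplers would draw identical mini-batches and the workers would duplicate each other's work instead of partitioning it. A small NumPy sketch of the effect (illustrative only; first_batch is not a DGL function):

import numpy as np

def first_batch(worker_id, seed):
    # Every worker shuffles the same pool of training node ids; the seed
    # decides whether the resulting mini-batches differ between workers.
    rng = np.random.RandomState(seed)
    ids = np.arange(100)
    rng.shuffle(ids)
    return ids[:8]

if __name__ == '__main__':
    # Same seed on both workers: identical batches, i.e. duplicated work.
    print(first_batch(0, seed=0), first_batch(1, seed=0))
    # Seeding with the worker id (as the diff does) gives each worker its own batches.
    print(first_batch(0, seed=0), first_batch(1, seed=1))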
@@ -4,6 +4,7 @@ import scipy
 from xmlrpc.server import SimpleXMLRPCServer
 import xmlrpc.client
 import numpy as np
+from functools import partial
 from collections.abc import MutableMapping
@@ -15,6 +16,7 @@ from ..graph_index import GraphIndex, create_graph_index
 from .._ffi.ndarray import empty_shared_mem
 from .._ffi.function import _init_api
 from .. import ndarray as nd
+from ..init import zero_initializer

 def _get_ndata_path(graph_name, ndata_name):
     return "/" + graph_name + "_node_" + ndata_name
@@ -213,6 +215,72 @@ class BarrierManager(object):
         if self.barriers[barrier_id].all_leave():
             del self.barriers[barrier_id]

+def shared_mem_zero_initializer(shape, dtype, name):  # pylint: disable=unused-argument
+    """Zero feature initializer in shared memory
+    """
+    data = empty_shared_mem(name, True, shape, dtype)
+    dlpack = data.to_dlpack()
+    arr = F.zerocopy_from_dlpack(dlpack)
+    arr[:] = 0
+    return arr
+
+class InitializerManager(object):
+    """Manage initializer.
+
+    We need to convert built-in frame initializer to strings
+    and send them to the graph store server through RPC.
+    Through the conversion, we need to convert local built-in initializer
+    to shared-memory initializer.
+    """
+    # Map the built-in initializer functions to strings.
+    _fun2str = {
+        zero_initializer: 'zero',
+    }
+
+    # Map the strings to built-in initializer functions.
+    _str2fun = {
+        'zero': shared_mem_zero_initializer,
+    }
+
+    def serialize(self, init):
+        """Convert the initializer function to string.
+
+        Parameters
+        ----------
+        init : callable
+            the initializer function.
+
+        Returns
+        -------
+        string
+            The name of the built-in initializer function.
+        """
+        if init in self._fun2str:
+            return self._fun2str[init]
+        else:
+            raise Exception("Shared-memory graph store doesn't support user's initializer")
+
+    def deserialize(self, init):
+        """Convert the string to the initializer function.
+
+        Parameters
+        ----------
+        init : string
+            the name of the initializer function
+
+        Returns
+        -------
+        callable
+            The shared-memory initializer function.
+        """
+        if init in self._str2fun:
+            return self._str2fun[init]
+        else:
+            raise Exception("Shared-memory graph store doesn't support initializer "
+                            + str(init))
+
 class SharedMemoryStoreServer(object):
     """The graph store server.
@@ -247,6 +315,7 @@ class SharedMemoryStoreServer(object):
         self._registered_nworkers = 0

         self._barrier = BarrierManager(num_workers)
+        self._init_manager = InitializerManager()

         # RPC command: register a graph to the graph store server.
         def register(graph_name):
@@ -265,29 +334,29 @@ class SharedMemoryStoreServer(object):
                    self._graph.is_multigraph, edge_dir

         # RPC command: initialize node embedding in the server.
-        def init_ndata(ndata_name, shape, dtype):
+        def init_ndata(init, ndata_name, shape, dtype):
             if ndata_name in self._graph.ndata:
                 ndata = self._graph.ndata[ndata_name]
                 assert np.all(ndata.shape == tuple(shape))
                 return 0

             assert self._graph.number_of_nodes() == shape[0]
-            data = empty_shared_mem(_get_ndata_path(graph_name, ndata_name), True, shape, dtype)
-            dlpack = data.to_dlpack()
-            self._graph.ndata[ndata_name] = F.zerocopy_from_dlpack(dlpack)
+            init = self._init_manager.deserialize(init)
+            data = init(shape, dtype, _get_ndata_path(graph_name, ndata_name))
+            self._graph.ndata[ndata_name] = data
             return 0

         # RPC command: initialize edge embedding in the server.
-        def init_edata(edata_name, shape, dtype):
+        def init_edata(init, edata_name, shape, dtype):
             if edata_name in self._graph.edata:
                 edata = self._graph.edata[edata_name]
                 assert np.all(edata.shape == tuple(shape))
                 return 0

             assert self._graph.number_of_edges() == shape[0]
-            data = empty_shared_mem(_get_edata_path(graph_name, edata_name), True, shape, dtype)
-            dlpack = data.to_dlpack()
-            self._graph.edata[edata_name] = F.zerocopy_from_dlpack(dlpack)
+            init = self._init_manager.deserialize(init)
+            data = init(shape, dtype, _get_edata_path(graph_name, edata_name))
+            self._graph.edata[edata_name] = data
             return 0

         # RPC command: get the names of all node embeddings.
@@ -332,6 +401,7 @@ class SharedMemoryStoreServer(object):
     def __del__(self):
         self._graph = None
+        self.server.server_close()

     @property
     def ndata(self):
@@ -392,6 +462,7 @@ class SharedMemoryDGLGraph(DGLGraph):
         graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
         graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir)
         super(SharedMemoryDGLGraph, self).__init__(graph_idx, multigraph=multigraph, readonly=True)
+        self._init_manager = InitializerManager()

         # map all ndata and edata from the server.
         ndata_infos = self.proxy.list_ndata()
@@ -404,29 +475,28 @@ class SharedMemoryDGLGraph(DGLGraph):
         # Set the ndata and edata initializers.
         # so that when a new node/edge embedding is created, it'll be created on the server as well.
-        def node_initializer(name, arr):
-            shape = F.shape(arr)
-            dtype = np.dtype(F.dtype(arr)).name
-            self.proxy.init_ndata(name, shape, dtype)
+        # These two functions create initialized tensors on the server.
+        def node_initializer(init, name, shape, dtype, ctx):
+            init = self._init_manager.serialize(init)
+            dtype = np.dtype(dtype).name
+            self.proxy.init_ndata(init, name, shape, dtype)
             data = empty_shared_mem(_get_ndata_path(self._graph_name, name),
                                     False, shape, dtype)
             dlpack = data.to_dlpack()
-            arr1 = F.zerocopy_from_dlpack(dlpack)
-            arr1[:] = arr
-            return arr1
-        def edge_initializer(name, arr):
-            shape = F.shape(arr)
-            dtype = np.dtype(F.dtype(arr)).name
-            self.proxy.init_edata(name, shape, dtype)
+            return F.zerocopy_from_dlpack(dlpack)
+        def edge_initializer(init, name, shape, dtype, ctx):
+            init = self._init_manager.serialize(init)
+            dtype = np.dtype(dtype).name
+            self.proxy.init_edata(init, name, shape, dtype)
             data = empty_shared_mem(_get_edata_path(self._graph_name, name),
                                     False, shape, dtype)
             dlpack = data.to_dlpack()
-            arr1 = F.zerocopy_from_dlpack(dlpack)
-            arr1[:] = arr
-            return arr1
+            return F.zerocopy_from_dlpack(dlpack)

-        self._node_frame.set_remote_initializer(node_initializer)
-        self._edge_frame.set_remote_initializer(edge_initializer)
-        self._msg_frame.set_remote_initializer(edge_initializer)
+        self._node_frame.set_remote_init_builder(lambda init, name: partial(node_initializer, init, name))
+        self._edge_frame.set_remote_init_builder(lambda init, name: partial(edge_initializer, init, name))
+        self._msg_frame.set_remote_init_builder(lambda init, name: partial(edge_initializer, init, name))

     def __del__(self):
         if self.proxy is not None:
@@ -490,7 +560,12 @@ class SharedMemoryDGLGraph(DGLGraph):
             The data type of the node embedding. The currently supported data types
             are "float32" and "int32".
         """
-        self.proxy.init_ndata(ndata_name, shape, dtype)
+        init = self._node_frame.get_initializer(ndata_name)
+        if init is None:
+            self._node_frame._frame._warn_and_set_initializer()
+            init = self._node_frame.get_initializer(ndata_name)
+        init = self._init_manager.serialize(init)
+        self.proxy.init_ndata(init, ndata_name, shape, dtype)
         self._init_ndata(ndata_name, shape, dtype)

     def init_edata(self, edata_name, shape, dtype):
@@ -509,7 +584,12 @@ class SharedMemoryDGLGraph(DGLGraph):
             The data type of the edge embedding. The currently supported data types
             are "float32" and "int32".
         """
-        self.proxy.init_edata(edata_name, shape, dtype)
+        init = self._edge_frame.get_initializer(edata_name)
+        if init is None:
+            self._edge_frame._frame._warn_and_set_initializer()
+            init = self._edge_frame.get_initializer(edata_name)
+        init = self._init_manager.serialize(init)
+        self.proxy.init_edata(init, edata_name, shape, dtype)
         self._init_edata(edata_name, shape, dtype)
......
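The graph-store changes above stop hard-coding "always zero-initialize in the server" and instead let the initializer travel over the XML-RPC channel by name: the client serializes the frame's initializer to a string, and the server deserializes that string into a shared-memory-backed initializer and materializes the tensor on its side. A simplified, self-contained sketch of this registry pattern follows; it uses plain NumPy instead of shared memory, and client_init_ndata, server_init_ndata, and the in-process rpc_call stand-in are illustrative names, not DGL APIs.

import numpy as np

def zero_initializer(shape, dtype, name):   # server-side stand-in
    # In DGL this would allocate shared memory; a plain NumPy array suffices here.
    return np.zeros(shape, dtype=dtype)

FUN2STR = {zero_initializer: 'zero'}   # client side: callable -> name
STR2FUN = {'zero': zero_initializer}   # server side: name -> callable

def client_init_ndata(rpc_call, init_fun, name, shape, dtype):
    # Serialize the initializer to a string before making the RPC.
    rpc_call(FUN2STR[init_fun], name, shape, dtype)

def server_init_ndata(store, init_name, name, shape, dtype):
    # Deserialize the string back into a callable and materialize the data.
    init = STR2FUN[init_name]
    store[name] = init(shape, dtype, name)
    return 0

if __name__ == '__main__':
    store = {}
    client_init_ndata(lambda *args: server_init_ndata(store, *args),
                      zero_initializer, 'h', (4, 3), 'float32')
    print(store['h'].shape, store['h'].dtype)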
@@ -210,7 +210,7 @@ class Frame(MutableMapping):
         # If is none, then a warning will be raised
         # in the first call and zero initializer will be used later.
         self._initializers = {}  # per-column initializers
-        self._remote_initializer = None
+        self._remote_init_builder = None
         self._default_initializer = None

     def _warn_and_set_initializer(self):
@@ -252,17 +252,34 @@ class Frame(MutableMapping):
         else:
             self._initializers[column] = initializer

-    def set_remote_initializer(self, initializer):
-        """Set the remote initializer when a column is added to the frame.
+    def set_remote_init_builder(self, builder):
+        """Set an initializer builder to create a remote initializer for a new column to a frame.

-        Initializer is a callable that returns a tensor given a local tensor and tensor name.
+        The builder is a callable that returns an initializer. The returned initializer
+        is also a callable that returns a tensor given a local tensor and tensor name.

         Parameters
         ----------
-        initializer : callable
-            The initializer.
+        builder : callable
+            The builder to construct a remote initializer.
         """
-        self._remote_initializer = initializer
+        self._remote_init_builder = builder
+
+    def get_remote_initializer(self, name):
+        """Get a remote initializer.
+
+        Parameters
+        ----------
+        name : string
+            The column name.
+        """
+        if self._remote_init_builder is None:
+            return None
+
+        if self.get_initializer(name) is None:
+            self._warn_and_set_initializer()
+        initializer = self.get_initializer(name)
+        return self._remote_init_builder(initializer, name)

     @property
     def schemes(self):
@@ -337,15 +354,18 @@ class Frame(MutableMapping):
         if name in self:
             dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
             return
-        if self.get_initializer(name) is None:
-            self._warn_and_set_initializer()
-        initializer = self.get_initializer(name)
-        init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
-                                ctx, slice(0, self.num_rows))
         # If the data is backed by a remote server, we need to move data
         # to the remote server.
-        if self._remote_initializer is not None:
-            init_data = self._remote_initializer(name, init_data)
+        initializer = self.get_remote_initializer(name)
+        if initializer is not None:
+            init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
+        else:
+            if self.get_initializer(name) is None:
+                self._warn_and_set_initializer()
+            initializer = self.get_initializer(name)
+            init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
+                                    ctx, slice(0, self.num_rows))
         self._columns[name] = Column(init_data, scheme)

     def add_rows(self, num_rows):
@@ -384,8 +404,11 @@ class Frame(MutableMapping):
         """
         # If the data is backed by a remote server, we need to move data
         # to the remote server.
-        if self._remote_initializer is not None:
-            data = self._remote_initializer(name, data)
+        initializer = self.get_remote_initializer(name)
+        if initializer is not None:
+            new_data = initializer(F.shape(data), F.dtype(data), F.context(data))
+            new_data[:] = data
+            data = new_data
         col = Column.create(data)
         if len(col) != self.num_rows:
             raise DGLError('Expected data to have %d rows, got %d.' %
@@ -393,7 +416,7 @@ class Frame(MutableMapping):
         self._columns[name] = col

     def _append(self, other):
-        assert self._remote_initializer is None, \
+        assert self._remote_init_builder is None, \
             "We don't support append if data in the frame is mapped from a remote server."
         # NOTE: `other` can be empty.
         if self.num_rows == 0:
@@ -512,17 +535,18 @@ class FrameRef(MutableMapping):
         """
         self._frame.set_initializer(initializer, column=column)

-    def set_remote_initializer(self, initializer):
-        """Set the remote initializer when a column is added to the frame.
+    def set_remote_init_builder(self, builder):
+        """Set an initializer builder to create a remote initializer for a new column to a frame.

-        Initializer is a callable that returns a tensor given a local tensor and tensor name.
+        The builder is a callable that returns an initializer. The returned initializer
+        is also a callable that returns a tensor given a local tensor and tensor name.

         Parameters
         ----------
-        initializer : callable
-            The initializer.
+        builder : callable
+            The builder to construct a remote initializer.
         """
-        self._frame.set_remote_initializer(initializer)
+        self._frame.set_remote_init_builder(builder)

     def get_initializer(self, column=None):
         """Get the initializer for empty values for the given column.
......
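The frame changes replace a single remote initializer with a two-level contract: set_remote_init_builder stores a builder, get_remote_initializer calls it with the column's local initializer and name, and the callable it returns produces the tensor from (shape, dtype, ctx). add_column then prefers that remote path and falls back to the local initializer only when no builder is set. A toy sketch of the contract follows, assuming a much-simplified frame; TinyFrame is illustrative and is not DGL's Frame.

import numpy as np

class TinyFrame(object):
    def __init__(self, num_rows):
        self.num_rows = num_rows
        self._columns = {}
        self._remote_init_builder = None

    def set_remote_init_builder(self, builder):
        self._remote_init_builder = builder

    def get_remote_initializer(self, initializer, name):
        if self._remote_init_builder is None:
            return None
        return self._remote_init_builder(initializer, name)

    def add_column(self, name, shape, dtype, ctx=None,
                   initializer=lambda shape, dtype, ctx: np.zeros(shape, dtype)):
        remote_init = self.get_remote_initializer(initializer, name)
        if remote_init is not None:
            # Remote path: the returned tensor would live on the server in DGL.
            data = remote_init((self.num_rows,) + shape, dtype, ctx)
        else:
            # Local path: plain in-process initialization.
            data = initializer((self.num_rows,) + shape, dtype, ctx)
        self._columns[name] = data

if __name__ == '__main__':
    frame = TinyFrame(num_rows=5)
    # A builder that simply wraps the local initializer for the given column.
    frame.set_remote_init_builder(
        lambda init, name: lambda shape, dtype, ctx: init(shape, dtype, ctx))
    frame.add_column('h', (3,), 'float32')
    print(frame._columns['h'].shape)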
@@ -37,7 +37,7 @@ SharedMemory::~SharedMemory() {
 void *SharedMemory::create_new(size_t size) {
   this->own = true;

-  int flag = O_RDWR|O_EXCL|O_CREAT;
+  int flag = O_RDWR|O_CREAT;
   fd = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
   CHECK_NE(fd, -1) << "fail to open " << name << ": " << strerror(errno);
   auto res = ftruncate(fd, size);
......
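Dropping O_EXCL lets create_new tolerate a shared-memory object that is left over from a previous run or was already created by another process: shm_open reopens the existing object and ftruncate resizes it, instead of the whole call failing. Below is a rough Python 3.8+ analogy using the standard library's multiprocessing.shared_memory; Python's create=True is always exclusive, so the non-exclusive behaviour is emulated with a fallback, and the segment name is illustrative.

from multiprocessing import shared_memory

def open_shared(name, size):
    # Emulate O_CREAT without O_EXCL: reuse the segment if it already exists.
    try:
        return shared_memory.SharedMemory(name=name, create=True, size=size)
    except FileExistsError:
        return shared_memory.SharedMemory(name=name, create=False)

if __name__ == '__main__':
    a = open_shared('demo_seg', 1024)   # creates the segment
    b = open_shared('demo_seg', 1024)   # reattaches instead of raising
    b.close()
    a.close()
    a.unlink()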
 import dgl
+import sys
+import random
 import time
 import numpy as np
 from multiprocessing import Process
@@ -6,17 +8,30 @@ from scipy import sparse as spsp
 import mxnet as mx
 import backend as F
 import unittest
+import dgl.function as fn

 num_nodes = 100
 num_edges = int(num_nodes * num_nodes * 0.1)
+rand_port = random.randint(5000, 8000)
+print('run graph store with port ' + str(rand_port), file=sys.stderr)

-def worker_func(worker_id):
+def check_array_shared_memory(g, worker_id, arrays):
+    if worker_id == 0:
+        for i, arr in enumerate(arrays):
+            arr[0] = i
+        g._sync_barrier()
+    else:
+        g._sync_barrier()
+        for i, arr in enumerate(arrays):
+            assert np.all(arr[0].asnumpy() == i)
+
+def check_init_func(worker_id, graph_name):
     time.sleep(3)
     print("worker starts")
     np.random.seed(0)
     csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
-    g = dgl.contrib.graph_store.create_graph_from_store("test_graph5", "shared_mem")
+    g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem", port=rand_port)

     # Verify the graph structure loaded from the shared memory.
     src, dst = g.all_edges()
     coo = csr.tocoo()
@@ -24,38 +39,55 @@ def worker_func(worker_id):
     assert F.array_equal(src, F.tensor(coo.col))
     assert F.array_equal(g.ndata['feat'][0], F.tensor(np.arange(10), dtype=np.float32))
     assert F.array_equal(g.edata['feat'][0], F.tensor(np.arange(10), dtype=np.float32))
-    g.ndata['test4'] = mx.nd.zeros((g.number_of_nodes(), 10))
-    g.edata['test4'] = mx.nd.zeros((g.number_of_edges(), 10))
-    if worker_id == 0:
-        time.sleep(3)
-        print(g.worker_id)
-        g.ndata['test4'][0] = 1
-        g.edata['test4'][0] = 2
-    else:
-        time.sleep(5)
-        print(g.worker_id)
-        assert np.all(g.ndata['test4'][0].asnumpy() == 1)
-        assert np.all(g.edata['test4'][0].asnumpy() == 2)
+    g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
+    g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
+    g._sync_barrier()
+    check_array_shared_memory(g, worker_id, [g.ndata['test4'], g.edata['test4']])

     g.destroy()

-def server_func(num_workers):
+def server_func(num_workers, graph_name):
     print("server starts")
     np.random.seed(0)
     csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
-    g = dgl.contrib.graph_store.create_graph_store_server(csr, "test_graph5", "shared_mem", num_workers,
-                                                          False, edge_dir="in")
+    g = dgl.contrib.graph_store.create_graph_store_server(csr, graph_name, "shared_mem", num_workers,
+                                                          False, edge_dir="in", port=rand_port)
     assert num_nodes == g._graph.number_of_nodes()
     assert num_edges == g._graph.number_of_edges()
     g.ndata['feat'] = mx.nd.arange(num_nodes * 10).reshape((num_nodes, 10))
     g.edata['feat'] = mx.nd.arange(num_edges * 10).reshape((num_edges, 10))
     g.run()

-@unittest.skip("disable shared memory test temporarily")
-def test_worker_server():
-    serv_p = Process(target=server_func, args=(2,))
-    work_p1 = Process(target=worker_func, args=(0,))
-    work_p2 = Process(target=worker_func, args=(1,))
+def test_test_init():
+    serv_p = Process(target=server_func, args=(2, 'test_graph1'))
+    work_p1 = Process(target=check_init_func, args=(0, 'test_graph1'))
+    work_p2 = Process(target=check_init_func, args=(1, 'test_graph1'))
+    serv_p.start()
+    work_p1.start()
+    work_p2.start()
+    serv_p.join()
+    work_p1.join()
+    work_p2.join()
+
+def check_update_all_func(worker_id, graph_name):
+    time.sleep(3)
+    print("worker starts")
+    g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem", port=rand_port)
+    g._sync_barrier()
+    g.dist_update_all(fn.copy_src(src='feat', out='m'),
+                      fn.sum(msg='m', out='preprocess'))
+    adj = g.adjacency_matrix()
+    tmp = mx.nd.dot(adj, g.ndata['feat'])
+    assert np.all((g.ndata['preprocess'] == tmp).asnumpy())
+    g._sync_barrier()
+    check_array_shared_memory(g, worker_id, [g.ndata['preprocess']])
+    g.destroy()
+
+def test_update_all():
+    serv_p = Process(target=server_func, args=(2, 'test_graph3'))
+    work_p1 = Process(target=check_update_all_func, args=(0, 'test_graph3'))
+    work_p2 = Process(target=check_update_all_func, args=(1, 'test_graph3'))
     serv_p.start()
     work_p1.start()
     work_p2.start()
@@ -64,4 +96,5 @@ def test_worker_server():
     work_p2.join()

 if __name__ == '__main__':
-    test_worker_server()
+    test_test_init()
+    test_update_all()
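The rewritten tests launch one server process plus two worker processes and use check_array_shared_memory to confirm that a write made by worker 0 is visible to worker 1 after a barrier, instead of relying on sleeps. A standalone sketch of that write/barrier/read verification pattern with only the standard library (not the DGL test itself; names are illustrative):

from multiprocessing import Array, Barrier, Process

def worker(worker_id, shared, barrier):
    if worker_id == 0:
        shared[0] = 42          # writer
        barrier.wait()
    else:
        barrier.wait()          # wait until the write has happened
        assert shared[0] == 42  # the reader sees the writer's update

if __name__ == '__main__':
    shared = Array('i', 4)      # int array backed by shared memory
    barrier = Barrier(2)
    procs = [Process(target=worker, args=(i, shared, barrier)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()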