Unverified commit ff5b5a4a, authored by Da Zheng and committed by GitHub

[Distributed] support no shared memory. (#1663)

* support no shared memory.

* add test.

* add CAPI to check existence of shared memory.

* revert the change in src/runtime/ndarray.cc

* update docstring.

* fix compile.
parent 0a4e8b32
"""DGL distributed.""" """DGL distributed."""
from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split
from .partition import partition_graph, load_partition from .partition import partition_graph, load_partition, load_partition_book
from .graph_partition_book import GraphPartitionBook, PartitionPolicy from .graph_partition_book import GraphPartitionBook, PartitionPolicy
from .rpc import * from .rpc import *
......
 @@ -248,8 +248,9 @@ class DistGraphServer(KVServer):
     graph partition. They all share the partition data (graph structure and node/edge data) with
     shared memory.

-    In addition, the partition data is also shared with the DistGraph clients that run on
-    the same machine.
+    By default, the partition data is shared with the DistGraph clients that run on
+    the same machine. However, a user can disable the shared-memory option, which is useful
+    when the server and the clients run on different machines.

     Parameters
     ----------
 @@ -263,18 +264,23 @@ class DistGraphServer(KVServer):
         The name of the graph. The server and the client need to specify the same graph name.
     conf_file : string
         The path of the config file generated by the partition tool.
+    disable_shared_mem : bool
+        Disable shared memory.
     '''
-    def __init__(self, server_id, ip_config, num_clients, graph_name, conf_file):
+    def __init__(self, server_id, ip_config, num_clients, graph_name, conf_file,
+                 disable_shared_mem=False):
         super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config,
                                               num_clients=num_clients)
         self.ip_config = ip_config
         # Load graph partition data.
         self.client_g, node_feats, edge_feats, self.gpb = load_partition(conf_file, server_id)
-        self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
+        if not disable_shared_mem:
+            self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
         # Init kvstore.
-        self.gpb.shared_memory(graph_name)
+        if not disable_shared_mem:
+            self.gpb.shared_memory(graph_name)
         self.add_part_policy(PartitionPolicy('node', server_id, self.gpb))
         self.add_part_policy(PartitionPolicy('edge', server_id, self.gpb))
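A minimal usage sketch of the new flag (the IP config path, graph name, and partition JSON path below are placeholders, not taken from this commit): the server still loads its partition, but skips copying the graph structure and the partition book into shared memory, so clients are expected to connect from other machines.

    from dgl.distributed import DistGraphServer

    # Start server 0 and keep the partition data out of shared memory.
    serv = DistGraphServer(server_id=0, ip_config='ip_config.txt', num_clients=1,
                           graph_name='test_graph',
                           conf_file='/tmp/dist_graph/test_graph.json',
                           disable_shared_mem=True)
    serv.start()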
 @@ -306,8 +312,14 @@ class DistGraph:
     ''' The DistGraph client.

     This provides the graph interface to access the partitioned graph data for distributed GNN
-    training. All data of partitions are loaded by the DistGraph server. The client doesn't need
-    to load any data.
+    training. All data of partitions are loaded by the DistGraph server.
+
+    By default, `DistGraph` uses shared memory to access the partition data on the local machine.
+    This gives the best performance for distributed training when `DistGraphServer` and
+    `DistGraph` run on the same machine. However, a user may want to run them on separate
+    machines. In that case, shared memory can be disabled by passing `disable_shared_mem=True`
+    when creating `DistGraphServer`, and the user has to pass a partition book to `DistGraph`.

     Parameters
     ----------
 @@ -315,13 +327,17 @@
         Path of IP configuration file.
     graph_name : str
         The name of the graph. This name has to be the same as the one used in DistGraphServer.
+    gpb : PartitionBook
+        The partition book object.
     '''
-    def __init__(self, ip_config, graph_name):
+    def __init__(self, ip_config, graph_name, gpb=None):
         connect_to_server(ip_config=ip_config)
         self._client = KVClient(ip_config)
         self._g = _get_graph_from_shared_mem(graph_name)
         self._gpb = get_shared_mem_partition_book(graph_name, self._g)
+        if self._gpb is None:
+            self._gpb = gpb
         self._client.barrier()
         self._client.map_shared_data(self._gpb)
         self._ndata = NodeDataView(self)
...
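On the client side, when the server was started with disable_shared_mem=True the shared-memory lookup of the partition book returns None, so the partition book has to be supplied explicitly. A sketch (the JSON path, IP config path, and graph name are placeholders):

    from dgl.distributed import DistGraph, load_partition_book

    # Load the partition book for partition 0 from the partition config file,
    # then hand it to the client because shared memory is unavailable.
    gpb = load_partition_book('/tmp/dist_graph/test_graph.json', 0)
    g = DistGraph('ip_config.txt', 'test_graph', gpb=gpb)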
 @@ -7,6 +7,7 @@ from ..base import NID, EID
 from .. import utils
 from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
 from .._ffi.ndarray import empty_shared_mem
+from ..ndarray import exist_shared_mem_array

 def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id,
                                  num_partitions, node_map, edge_map, is_range_part):
 @@ -70,6 +71,8 @@ def get_shared_mem_partition_book(graph_name, graph_part):
     GraphPartitionBook or RangePartitionBook
         A graph partition book for a particular partition.
     '''
+    if not exist_shared_mem_array(_get_ndata_path(graph_name, 'meta')):
+        return None
     is_range_part, part_id, num_parts, node_map, edge_map = _get_shared_mem_metadata(graph_name)
     if is_range_part == 1:
         return RangePartitionBook(part_id, num_parts, node_map, edge_map)
...
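The new guard turns a missing shared-memory segment into a soft failure: if the server never called gpb.shared_memory(graph_name), i.e. it was started with disable_shared_mem=True, the lookup returns None instead of raising, and the caller falls back to a partition book obtained some other way. A sketch of that calling pattern, where local_graph and fallback_gpb are hypothetical stand-ins rather than names from this commit:

    gpb = get_shared_mem_partition_book('test_graph', local_graph)
    if gpb is None:
        # Shared memory is disabled on the server; use a partition book
        # loaded from the partition config file instead.
        gpb = fallback_gpb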
 @@ -121,8 +121,6 @@ def load_partition(conf_file, part_id):
     '''
     with open(conf_file) as conf_f:
         part_metadata = json.load(conf_f)
-    assert 'num_parts' in part_metadata, 'num_parts does not exist.'
-    num_parts = part_metadata['num_parts']
     assert 'part-{}'.format(part_id) in part_metadata, "part-{} does not exist".format(part_id)
     part_files = part_metadata['part-{}'.format(part_id)]
     assert 'node_feats' in part_files, "the partition does not contain node features."
 @@ -131,6 +129,34 @@ def load_partition(conf_file, part_id):
     node_feats = load_tensors(part_files['node_feats'])
     edge_feats = load_tensors(part_files['edge_feats'])
     graph = load_graphs(part_files['part_graph'])[0][0]
+    assert NID in graph.ndata, "the partition graph should contain node mapping to global node Id"
+    assert EID in graph.edata, "the partition graph should contain edge mapping to global edge Id"
+
+    gpb = load_partition_book(conf_file, part_id, graph)
+    return graph, node_feats, edge_feats, gpb
+
+def load_partition_book(conf_file, part_id, graph=None):
+    ''' Load a graph partition book from the partition config file.
+
+    Parameters
+    ----------
+    conf_file : str
+        The path of the partition config file.
+    part_id : int
+        The partition Id.
+    graph : DGLGraph
+        The graph structure.
+
+    Returns
+    -------
+    GraphPartitionBook
+        The global partition information.
+    '''
+    with open(conf_file) as conf_f:
+        part_metadata = json.load(conf_f)
+    assert 'num_parts' in part_metadata, 'num_parts does not exist.'
+    num_parts = part_metadata['num_parts']
     assert 'num_nodes' in part_metadata, "cannot get the number of nodes of the global graph."
     assert 'num_edges' in part_metadata, "cannot get the number of edges of the global graph."
     assert 'node_map' in part_metadata, "cannot get the node map."
 @@ -145,14 +171,10 @@ def load_partition(conf_file, part_id):
     assert isinstance(node_map, list) == isinstance(edge_map, list), \
         "The node map and edge map need to have the same format"
-    assert NID in graph.ndata, "the partition graph should contain node mapping to global node Id"
-    assert EID in graph.edata, "the partition graph should contain edge mapping to global edge Id"
     if is_range_part:
-        gpb = RangePartitionBook(part_id, num_parts, np.array(node_map), np.array(edge_map))
+        return RangePartitionBook(part_id, num_parts, np.array(node_map), np.array(edge_map))
     else:
-        gpb = GraphPartitionBook(part_id, num_parts, node_map, edge_map, graph)
+        return GraphPartitionBook(part_id, num_parts, node_map, edge_map, graph)
-    return graph, node_feats, edge_feats, gpb

 def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method="metis",
                     reshuffle=True):
...
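With this refactoring, a process that only needs the mapping from global node/edge IDs to partitions can load the partition book without touching the partition's graph structure or features. A sketch of both entry points (the JSON path is a placeholder):

    from dgl.distributed import load_partition, load_partition_book

    # Full load: graph structure, node/edge features, and the partition book.
    graph, node_feats, edge_feats, gpb = load_partition('/tmp/dist_graph/test_graph.json', 0)

    # Metadata only: the partition book on its own.
    gpb_only = load_partition_book('/tmp/dist_graph/test_graph.json', 0)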
 @@ -90,6 +90,21 @@ def zerocopy_from_numpy(np_data):
     handle = ctypes.pointer(arr)
     return NDArray(handle, is_view=True)

+def exist_shared_mem_array(name):
+    """ Check the existence of a shared-memory array.
+
+    Parameters
+    ----------
+    name : str
+        The name of the shared-memory array.
+
+    Returns
+    -------
+    bool
+        Whether the array exists.
+    """
+    return _CAPI_DGLExistSharedMemArray(name)
+
 class SparseFormat:
     """Format code"""
     ANY = 0
...
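The new helper is a thin wrapper around the C API registered below; on Windows the C side always reports false. A usage sketch (the array name is a made-up example, not the naming scheme DGL uses internally):

    from dgl.ndarray import exist_shared_mem_array

    # True only if a shared-memory array with this exact name has been created.
    print(exist_shared_mem_array('test_graph_meta'))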
 @@ -6,6 +6,7 @@
 #include <dgl/array.h>
 #include <dgl/packed_func_ext.h>
 #include <dgl/runtime/container.h>
+#include <dgl/runtime/shared_mem.h>
 #include "../c_api_common.h"
 #include "./array_op.h"
 #include "./arith.h"
 @@ -702,5 +703,15 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix")
     *rv = SparseMatrixRef(spmat);
   });

+DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+  const std::string name = args[0];
+#ifndef _WIN32
+  *rv = SharedMemory::Exist(name);
+#else
+  *rv = false;
+#endif // _WIN32
+});
+
 } // namespace aten
 } // namespace dgl
 @@ -12,7 +12,7 @@ import multiprocessing as mp
 from dgl.graph_index import create_graph_index
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import DistGraphServer, DistGraph
-from dgl.distributed import partition_graph, load_partition, GraphPartitionBook, node_split, edge_split
+from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
 import backend as F
 import unittest
 import pickle
 @@ -51,14 +51,17 @@ def create_random_graph(n):
     ig = create_graph_index(arr, readonly=True)
     return dgl.DGLGraph(ig)

-def run_server(graph_name, server_id, num_clients):
+def run_server(graph_name, server_id, num_clients, shared_mem):
     g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name,
-                        '/tmp/dist_graph/{}.json'.format(graph_name))
+                        '/tmp/dist_graph/{}.json'.format(graph_name),
+                        disable_shared_mem=not shared_mem)
     print('start server', server_id)
     g.start()

-def run_client(graph_name, num_nodes, num_edges):
-    g = DistGraph("kv_ip_config.txt", graph_name)
+def run_client(graph_name, part_id, num_nodes, num_edges):
+    gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
+                              part_id, None)
+    g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)

     # Test API
     assert g.number_of_nodes() == num_nodes
 @@ -116,8 +119,7 @@ def run_client(graph_name, num_nodes, num_edges):
     dgl.distributed.finalize_client()
     print('end')

-@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
-def test_server_client():
+def check_server_client(shared_mem):
     prepare_dist()
     g = create_random_graph(10000)
 @@ -133,14 +135,14 @@ def test_server_client():
     serv_ps = []
     ctx = mp.get_context('spawn')
     for serv_id in range(1):
-        p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1))
+        p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1, shared_mem))
         serv_ps.append(p)
         p.start()

     cli_ps = []
     for cli_id in range(1):
         print('start client', cli_id)
-        p = ctx.Process(target=run_client, args=(graph_name, g.number_of_nodes(),
+        p = ctx.Process(target=run_client, args=(graph_name, cli_id, g.number_of_nodes(),
                                                  g.number_of_edges()))
         p.start()
         cli_ps.append(p)
 @@ -153,6 +155,11 @@ def test_server_client():
     print('clients have terminated')

+@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
+def test_server_client():
+    check_server_client(True)
+    check_server_client(False)
+
 def test_split():
     prepare_dist()
     g = create_random_graph(10000)
...