Unverified Commit a1472bcf authored by Da Zheng, committed by GitHub

[Distributed] Run distributed graph server inside DGL (#1801)

* run dist server in dgl.

* fix bugs.

* fix example.

* check environment variables and fix lint.

* fix lint
parent be53add4
@@ -28,15 +28,31 @@ will be able to access the partitioned data.
 We need to run a server on each machine. Before running the servers, we need to update `ip_config.txt` with the right IP addresses.
+On each of the machines, set the following environment variables.
+```bash
+export DGL_ROLE=server
+export DGL_IP_CONFIG=ip_config.txt
+export DGL_CONF_PATH=data/ogb-product.json
+export DGL_NUM_CLIENT=4
+```
 ```bash
 # run server on machine 0
-python3 train_dist.py --server --graph-name ogb-product --id 0 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
+export DGL_SERVER_ID=0
+python3 train_dist.py
 # run server on machine 1
-python3 train_dist.py --server --graph-name ogb-product --id 1 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
+export DGL_SERVER_ID=1
+python3 train_dist.py
 # run server on machine 2
-python3 train_dist.py --server --graph-name ogb-product --id 2 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
+export DGL_SERVER_ID=2
+python3 train_dist.py
 # run server on machine 3
-python3 train_dist.py --server --graph-name ogb-product --id 3 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
+export DGL_SERVER_ID=3
+python3 train_dist.py
 ```
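For quick testing, the four server commands above can also be driven from a single launcher on one machine. The sketch below is illustrative only: it assumes `train_dist.py`, `ip_config.txt`, and the partition config live in the working directory, and it starts all four servers locally, whereas the real deployment runs one per machine.

```python
import os
import subprocess

# Shared server environment (values taken from the README above).
base_env = dict(os.environ,
                DGL_ROLE='server',
                DGL_IP_CONFIG='ip_config.txt',
                DGL_CONF_PATH='data/ogb-product.json',
                DGL_NUM_CLIENT='4')

# Spawn one server per partition; each gets its own DGL_SERVER_ID.
procs = [subprocess.Popen(['python3', 'train_dist.py'],
                          env=dict(base_env, DGL_SERVER_ID=str(i)))
         for i in range(4)]
for p in procs:
    p.wait()  # servers exit once all clients disconnect
```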
 ### Step 4: run trainers
@@ -45,11 +61,11 @@ Pytorch distributed requires one of the trainer process to be the master. Here w
 ```bash
 # run client on machine 0
-python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
+python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
 # run client on machine 1
-python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
+python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
 # run client on machine 2
-python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
+python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
 # run client on machine 3
-python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
+python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
 ```
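Note that no `DGL_*` variables are exported on the trainer side: as the `__init__.py` change below shows, `DGL_ROLE` defaults to `'client'`, so the same `train_dist.py` runs as a regular trainer. A one-line sanity check, purely for illustration:

```python
import os

# Trainers leave DGL_ROLE unset, so dgl.distributed treats the process
# as a client and skips the server bootstrap entirely.
assert os.environ.get('DGL_ROLE', 'client') == 'client'
```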
@@ -43,11 +43,6 @@ class NeighborSampler(object):
             blocks.insert(0, block)
         return blocks
 
-def start_server(args):
-    serv = dgl.distributed.DistGraphServer(args.id, args.ip_config, args.num_client,
-                                           args.graph_name, args.conf_path)
-    serv.start()
-
 def run(args, device, data):
     # Unpack data
     train_nid, val_nid, in_feats, n_classes, g = data
@@ -181,8 +176,6 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')
     register_data_args(parser)
-    parser.add_argument('--server', action='store_true',
-                        help='whether this is a server.')
     parser.add_argument('--graph-name', type=str, help='graph name')
     parser.add_argument('--id', type=int, help='the partition id')
     parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
@@ -206,8 +199,4 @@ if __name__ == '__main__':
     args = parser.parse_args()
     print(args)
-    if args.server:
-        start_server(args)
-    else:
-        main(args)
+    main(args)
"""DGL distributed.""" """DGL distributed."""
import os
import sys
from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split
from .partition import partition_graph, load_partition, load_partition_book from .partition import partition_graph, load_partition, load_partition_book
...@@ -11,3 +13,19 @@ from .rpc_client import connect_to_server, finalize_client, shutdown_servers ...@@ -11,3 +13,19 @@ from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .kvstore import KVServer, KVClient from .kvstore import KVServer, KVClient
from .server_state import ServerState from .server_state import ServerState
from .graph_services import sample_neighbors, in_subgraph from .graph_services import sample_neighbors, in_subgraph
if os.environ.get('DGL_ROLE', 'client') == 'server':
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
SERV.start()
sys.exit()
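With this bootstrap in place, merely importing `dgl.distributed` in a process whose `DGL_ROLE` is `server` constructs the server, serves until shutdown, and exits; the import never returns control to the caller. A minimal sketch, with placeholder paths and counts:

```python
import os

# Placeholders: point these at your own IP config and partition output.
os.environ['DGL_ROLE'] = 'server'
os.environ['DGL_SERVER_ID'] = '0'
os.environ['DGL_IP_CONFIG'] = 'ip_config.txt'
os.environ['DGL_NUM_CLIENT'] = '4'
os.environ['DGL_CONF_PATH'] = 'data/ogb-product.json'

# The import below runs the bootstrap shown above: it builds a
# DistGraphServer, serves requests, and finally calls sys.exit().
import dgl.distributed  # noqa: E402
```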
@@ -284,20 +284,19 @@ class DistGraphServer(KVServer):
         Path of IP configuration file.
     num_clients : int
         Total number of client nodes.
-    graph_name : string
-        The name of the graph. The server and the client need to specify the same graph name.
     conf_file : string
         The path of the config file generated by the partition tool.
     disable_shared_mem : bool
         Disable shared memory.
     '''
-    def __init__(self, server_id, ip_config, num_clients, graph_name, conf_file,
-                 disable_shared_mem=False):
+    def __init__(self, server_id, ip_config, num_clients, conf_file, disable_shared_mem=False):
         super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config,
                                               num_clients=num_clients)
         self.ip_config = ip_config
         # Load graph partition data.
-        self.client_g, node_feats, edge_feats, self.gpb = load_partition(conf_file, server_id)
+        self.client_g, node_feats, edge_feats, self.gpb, graph_name = load_partition(conf_file,
+                                                                                     server_id)
+        print('load ' + graph_name)
         if not disable_shared_mem:
             self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
@@ -326,6 +325,7 @@ class DistGraphServer(KVServer):
         """
         # start server
         server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
+        print('start graph service on server ' + str(self.server_id))
         start_server(server_id=self.server_id, ip_config=self.ip_config,
                      num_clients=self.num_clients, server_state=server_state)
...
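After this change a server is constructed without a graph name; the name is recovered from the partition config and reused for the shared-memory copy. A usage sketch with placeholder arguments:

```python
from dgl.distributed import DistGraphServer

# conf_file points at the JSON emitted by partition_graph(); the graph
# name is read back from that file rather than passed in.
serv = DistGraphServer(server_id=0,
                       ip_config='ip_config.txt',
                       num_clients=4,
                       conf_file='data/ogb-product.json')
serv.start()  # blocks, serving graph and KV requests
```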
@@ -119,6 +119,8 @@ def load_partition(conf_file, part_id):
         All edge features.
     GraphPartitionBook
         The global partition information.
+    str
+        The graph name
     '''
     with open(conf_file) as conf_f:
         part_metadata = json.load(conf_f)
@@ -134,8 +136,8 @@ def load_partition(conf_file, part_id):
     assert NID in graph.ndata, "the partition graph should contain node mapping to global node Id"
     assert EID in graph.edata, "the partition graph should contain edge mapping to global edge Id"
 
-    gpb = load_partition_book(conf_file, part_id, graph)
-    return graph, node_feats, edge_feats, gpb
+    gpb, graph_name = load_partition_book(conf_file, part_id, graph)
+    return graph, node_feats, edge_feats, gpb, graph_name
 
 def load_partition_book(conf_file, part_id, graph=None):
     ''' Load a graph partition book from the partition config file.
@@ -153,6 +155,8 @@ def load_partition_book(conf_file, part_id, graph=None):
     -------
     GraphPartitionBook
         The global partition information.
+    str
+        The graph name
     '''
     with open(conf_file) as conf_f:
         part_metadata = json.load(conf_f)
@@ -162,6 +166,7 @@ def load_partition_book(conf_file, part_id, graph=None):
     assert 'num_edges' in part_metadata, "cannot get the number of edges of the global graph."
     assert 'node_map' in part_metadata, "cannot get the node map."
     assert 'edge_map' in part_metadata, "cannot get the edge map."
+    assert 'graph_name' in part_metadata, "cannot get the graph name"
 
     # If this is a range partitioning, node_map actually stores a list, whose elements
     # indicate the boundary of range partitioning. Otherwise, node_map stores a filename
@@ -173,9 +178,11 @@ def load_partition_book(conf_file, part_id, graph=None):
         "The node map and edge map need to have the same format"
     if is_range_part:
-        return RangePartitionBook(part_id, num_parts, np.array(node_map), np.array(edge_map))
+        return RangePartitionBook(part_id, num_parts, np.array(node_map),
+                                  np.array(edge_map)), part_metadata['graph_name']
     else:
-        return GraphPartitionBook(part_id, num_parts, node_map, edge_map, graph)
+        return GraphPartitionBook(part_id, num_parts, node_map, edge_map,
+                                  graph), part_metadata['graph_name']
 
 def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method="metis",
                     reshuffle=True, balance_ntypes=None, balance_edges=False):
...
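Callers now unpack the extra graph-name return value from both loaders. A sketch with a placeholder config path:

```python
from dgl.distributed import load_partition, load_partition_book

CONF = '/tmp/dist_graph/mygraph.json'  # placeholder partition config

# load_partition returns the graph name as a fifth element ...
part_g, node_feats, edge_feats, gpb, graph_name = load_partition(CONF, 0)

# ... and load_partition_book returns it alongside the partition book.
gpb, graph_name = load_partition_book(CONF, 0)
```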
@@ -55,7 +55,7 @@ def create_random_graph(n):
     return dgl.DGLGraph(ig)
 
 def run_server(graph_name, server_id, num_clients, shared_mem):
-    g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name,
+    g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients,
                         '/tmp/dist_graph/{}.json'.format(graph_name),
                         disable_shared_mem=not shared_mem)
     print('start server', server_id)
@@ -66,7 +66,7 @@ def emb_init(shape, dtype):
 def run_client(graph_name, part_id, num_nodes, num_edges):
     time.sleep(5)
-    gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
-                              part_id, None)
+    gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
+                                          part_id, None)
     g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
@@ -221,7 +221,7 @@ def test_split():
     selected_edges = np.nonzero(edge_mask)[0]
     for i in range(num_parts):
         dgl.distributed.set_num_client(num_parts)
-        part_g, node_feats, edge_feats, gpb = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
+        part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
         local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
         local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
         nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids))
@@ -270,7 +270,7 @@ def test_split_even():
     all_edges2 = []
     for i in range(num_parts):
         dgl.distributed.set_num_client(num_parts)
-        part_g, node_feats, edge_feats, gpb = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
+        part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
         local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
         local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
         nodes = node_split(node_mask, gpb, i, force_even=True)
...
@@ -17,7 +17,7 @@ from dgl.distributed import DistGraphServer, DistGraph
 def start_server(rank, tmpdir, disable_shared_mem, graph_name):
     import dgl
-    g = DistGraphServer(rank, "rpc_ip_config.txt", 1, graph_name,
+    g = DistGraphServer(rank, "rpc_ip_config.txt", 1,
                         tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
     g.start()
@@ -26,7 +26,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
     import dgl
     gpb = None
     if disable_shared_mem:
-        _, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
+        _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
     dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
     dgl.distributed.shutdown_servers()
@@ -108,7 +108,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
     orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64)
     orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64)
     for i in range(num_server):
-        part, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
+        part, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
         orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
         orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
@@ -134,7 +134,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
     import dgl
     gpb = None
     if disable_shared_mem:
-        _, _, _, gpb = load_partition(tmpdir / 'test_in_subgraph.json', rank)
+        _, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
     dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
     sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
     dgl.distributed.shutdown_servers()
...
@@ -31,7 +31,7 @@ def check_partition(part_method, reshuffle):
                     part_method=part_method, reshuffle=reshuffle)
     part_sizes = []
     for i in range(num_parts):
-        part_g, node_feats, edge_feats, gpb = load_partition('/tmp/partition/test.json', i)
+        part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/partition/test.json', i)
         # Check the metadata
         assert gpb._num_nodes() == g.number_of_nodes()
...