Unverified Commit a1472bcf authored by Da Zheng, committed by GitHub

[Distributed] Run distributed graph server inside DGL (#1801)

* run dist server in dgl.

* fix bugs.

* fix example.

* check environment variables and fix lint.

* fix lint
parent be53add4
@@ -28,15 +28,31 @@ will be able to access the partitioned data.
We need to run a server on each machine. Before running the servers, update `ip_config.txt` with the correct IP addresses.
On each machine, set the following environment variables.
```bash
export DGL_ROLE=server
export DGL_IP_CONFIG=ip_config.txt
export DGL_CONF_PATH=data/ogb-product.json
export DGL_NUM_CLIENT=4
```
```bash
# run server on machine 0
python3 train_dist.py --server --graph-name ogb-product --id 0 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
export DGL_SERVER_ID=0
python3 train_dist.py
# run server on machine 1
python3 train_dist.py --server --graph-name ogb-product --id 1 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
export DGL_SERVER_ID=1
python3 train_dist.py
# run server on machine 2
python3 train_dist.py --server --graph-name ogb-product --id 2 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
export DGL_SERVER_ID=2
python3 train_dist.py
# run server on machine 3
python3 train_dist.py --server --graph-name ogb-product --id 3 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
export DGL_SERVER_ID=3
python3 train_dist.py
```
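These variables are read by `dgl.distributed` at import time, which constructs a `DistGraphServer` and starts it. The same server can also be started programmatically; the sketch below assumes the example paths used in this tutorial and that `DGL_ROLE` is left unset (its default is `client`), so the import itself does not launch a server.
```python
from dgl.distributed import DistGraphServer

# A minimal sketch, assuming the example paths used above.
serv = DistGraphServer(server_id=0,                        # same value as DGL_SERVER_ID
                       ip_config='ip_config.txt',          # same file as DGL_IP_CONFIG
                       num_clients=4,                       # same value as DGL_NUM_CLIENT
                       conf_file='data/ogb-product.json')   # same file as DGL_CONF_PATH
serv.start()
```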
### Step 4: run trainers
@@ -45,11 +61,11 @@ PyTorch distributed requires one of the trainer processes to be the master. Here w
```bash
# run client on machine 0
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 2
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 3
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
```
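Each trainer process connects to the graph servers through `dgl.distributed.DistGraph`. The sketch below outlines the client-side setup; only the `DistGraph` and `sample_neighbors` calls reflect APIs used elsewhere in this commit, while the backend choice and hard-coded names are illustrative assumptions about how `train_dist.py` is organized.
```python
import torch.distributed as th_dist
from dgl.distributed import DistGraph, sample_neighbors

# torch.distributed.launch supplies rank and world size through the environment.
th_dist.init_process_group(backend='gloo')

# Connect to the running servers listed in ip_config.txt.
g = DistGraph('ip_config.txt', 'ogb-product')

# Distributed neighbor sampling over the partitioned graph (seed IDs are examples).
sampled = sample_neighbors(g, [0, 10, 99], 3)
```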
@@ -43,11 +43,6 @@ class NeighborSampler(object):
blocks.insert(0, block)
return blocks
def start_server(args):
serv = dgl.distributed.DistGraphServer(args.id, args.ip_config, args.num_client,
args.graph_name, args.conf_path)
serv.start()
def run(args, device, data):
# Unpack data
train_nid, val_nid, in_feats, n_classes, g = data
@@ -181,8 +176,6 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument('--server', action='store_true',
help='whether this is a server.')
parser.add_argument('--graph-name', type=str, help='graph name')
parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
@@ -206,8 +199,4 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
if args.server:
start_server(args)
else:
main(args)
"""DGL distributed."""
import os
import sys
from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split
from .partition import partition_graph, load_partition, load_partition_book
@@ -11,3 +13,19 @@ from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .graph_services import sample_neighbors, in_subgraph
if os.environ.get('DGL_ROLE', 'client') == 'server':
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
SERV.start()
sys.exit()
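The bootstrap above is equivalent to exporting the variables in a shell before launching the training script; a hedged sketch of driving it from Python, with example paths:
```python
import os

# Assumed example values; they mirror the exports shown in the README.
os.environ.update({
    'DGL_ROLE': 'server',
    'DGL_SERVER_ID': '0',
    'DGL_IP_CONFIG': 'ip_config.txt',
    'DGL_NUM_CLIENT': '4',
    'DGL_CONF_PATH': 'data/ogb-product.json',
})

# Importing the package now runs DistGraphServer.start() and exits afterwards.
import dgl.distributed
```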
@@ -284,20 +284,19 @@ class DistGraphServer(KVServer):
Path of IP configuration file.
num_clients : int
Total number of client nodes.
graph_name : string
The name of the graph. The server and the client need to specify the same graph name.
conf_file : string
The path of the config file generated by the partition tool.
disable_shared_mem : bool
Disable shared memory.
'''
def __init__(self, server_id, ip_config, num_clients, graph_name, conf_file,
disable_shared_mem=False):
def __init__(self, server_id, ip_config, num_clients, conf_file, disable_shared_mem=False):
super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config,
num_clients=num_clients)
self.ip_config = ip_config
# Load graph partition data.
self.client_g, node_feats, edge_feats, self.gpb = load_partition(conf_file, server_id)
self.client_g, node_feats, edge_feats, self.gpb, graph_name = load_partition(conf_file,
server_id)
print('load ' + graph_name)
if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
@@ -326,6 +325,7 @@ class DistGraphServer(KVServer):
"""
# start server
server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
print('start graph service on server ' + str(self.server_id))
start_server(server_id=self.server_id, ip_config=self.ip_config,
num_clients=self.num_clients, server_state=server_state)
@@ -119,6 +119,8 @@ def load_partition(conf_file, part_id):
All edge features.
GraphPartitionBook
The global partition information.
str
The graph name
'''
with open(conf_file) as conf_f:
part_metadata = json.load(conf_f)
@@ -134,8 +136,8 @@ def load_partition(conf_file, part_id):
assert NID in graph.ndata, "the partition graph should contain node mapping to global node Id"
assert EID in graph.edata, "the partition graph should contain edge mapping to global edge Id"
gpb = load_partition_book(conf_file, part_id, graph)
return graph, node_feats, edge_feats, gpb
gpb, graph_name = load_partition_book(conf_file, part_id, graph)
return graph, node_feats, edge_feats, gpb, graph_name
def load_partition_book(conf_file, part_id, graph=None):
''' Load a graph partition book from the partition config file.
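With this change, callers of `load_partition` unpack five values instead of four; a minimal sketch, assuming an example partition config path:
```python
from dgl.distributed import load_partition

# The graph name is now returned alongside the partition data.
part_g, node_feats, edge_feats, gpb, graph_name = load_partition(
    'data/ogb-product.json', 0)
```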
@@ -153,6 +155,8 @@ def load_partition_book(conf_file, part_id, graph=None):
-------
GraphPartitionBook
The global partition information.
str
The graph name
'''
with open(conf_file) as conf_f:
part_metadata = json.load(conf_f)
@@ -162,6 +166,7 @@ def load_partition_book(conf_file, part_id, graph=None):
assert 'num_edges' in part_metadata, "cannot get the number of edges of the global graph."
assert 'node_map' in part_metadata, "cannot get the node map."
assert 'edge_map' in part_metadata, "cannot get the edge map."
assert 'graph_name' in part_metadata, "cannot get the graph name"
# If this is a range partitioning, node_map actually stores a list, whose elements
# indicate the boundary of range partitioning. Otherwise, node_map stores a filename
@@ -173,9 +178,11 @@ def load_partition_book(conf_file, part_id, graph=None):
"The node map and edge map need to have the same format"
if is_range_part:
return RangePartitionBook(part_id, num_parts, np.array(node_map), np.array(edge_map))
return RangePartitionBook(part_id, num_parts, np.array(node_map),
np.array(edge_map)), part_metadata['graph_name']
else:
return GraphPartitionBook(part_id, num_parts, node_map, edge_map, graph)
return GraphPartitionBook(part_id, num_parts, node_map, edge_map,
graph), part_metadata['graph_name']
def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method="metis",
reshuffle=True, balance_ntypes=None, balance_edges=False):
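Likewise, `load_partition_book` now returns the partition book together with the graph name, which the updated tests pass straight to `DistGraph`; a sketch with example paths (passing `gpb` explicitly is only needed when shared memory is disabled, as in the tests):
```python
from dgl.distributed import load_partition_book, DistGraph

gpb, graph_name = load_partition_book('data/ogb-product.json', 0)
g = DistGraph('ip_config.txt', graph_name, gpb=gpb)
```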
@@ -55,7 +55,7 @@ def create_random_graph(n):
return dgl.DGLGraph(ig)
def run_server(graph_name, server_id, num_clients, shared_mem):
g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name,
g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients,
'/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem)
print('start server', server_id)
@@ -66,7 +66,7 @@ def emb_init(shape, dtype):
def run_client(graph_name, part_id, num_nodes, num_edges):
time.sleep(5)
gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
@@ -221,7 +221,7 @@ def test_split():
selected_edges = np.nonzero(edge_mask)[0]
for i in range(num_parts):
dgl.distributed.set_num_client(num_parts)
part_g, node_feats, edge_feats, gpb = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids))
@@ -270,7 +270,7 @@ def test_split_even():
all_edges2 = []
for i in range(num_parts):
dgl.distributed.set_num_client(num_parts)
part_g, node_feats, edge_feats, gpb = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
nodes = node_split(node_mask, gpb, i, force_even=True)
@@ -17,7 +17,7 @@ from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir, disable_shared_mem, graph_name):
import dgl
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, graph_name,
g = DistGraphServer(rank, "rpc_ip_config.txt", 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
g.start()
@@ -26,7 +26,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dgl.distributed.shutdown_servers()
@@ -108,7 +108,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64)
orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64)
for i in range(num_server):
part, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
part, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
@@ -134,7 +134,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb = load_partition(tmpdir / 'test_in_subgraph.json', rank)
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
dgl.distributed.shutdown_servers()
@@ -31,7 +31,7 @@ def check_partition(part_method, reshuffle):
part_method=part_method, reshuffle=reshuffle)
part_sizes = []
for i in range(num_parts):
part_g, node_feats, edge_feats, gpb = load_partition('/tmp/partition/test.json', i)
part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/partition/test.json', i)
# Check the metadata
assert gpb._num_nodes() == g.number_of_nodes()