Unverified Commit 18dbaebe authored by Da Zheng, committed by GitHub

[Distributed] Specify the graph format for distributed training (#2948)



* explicitly set the graph format.

* fix.

* fix.

* fix launch script.

* fix readme.

Co-authored-by: Zheng <dzzhen@3c22fba32af5.ant.amazon.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-71-112.ec2.internal>
parent 1db4ad4f
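For context: a `DGLGraph` stores its structure in one or more sparse formats (`coo`, `csr`, `csc`), and restricting a graph to only the formats it needs saves memory on the servers. A minimal sketch of the format API this commit builds on (the toy graph is illustrative, not from the commit):

```
import dgl
import torch as th

# A toy graph; by default all three sparse formats are allowed.
g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
print(g.formats())             # reports which formats are created/allowed

# Restrict the graph to CSC and COO, then materialize them eagerly.
g = g.formats(['csc', 'coo'])
g.create_formats_()
```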
@@ -135,6 +135,7 @@ python3 ~/workspace/dgl/tools/launch.py \
     --num_servers 1 \
     --part_config data/ogb-product.json \
     --ip_config ip_config.txt \
+    --graph_format csc,coo \
     "python3 train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000"
 ```
@@ -183,6 +184,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
     --num_servers 1 \
     --part_config data/ogb-product.json \
     --ip_config ip_config.txt \
+    --graph_format csc,coo \
     "python3 train_dist_unsupervised_transductive.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000 --num_gpus 4"
 ```
@@ -194,6 +196,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
     --num_servers 1 \
     --part_config data/ogb-product.json \
     --ip_config ip_config.txt \
+    --graph_format csc,coo \
     "python3 train_dist_unsupervised_transductive.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000 --num_gpus 4 --dgl_sparse"
 ```
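All three launch commands above pass the same `--graph_format csc,coo`; presumably CSC serves neighbor sampling while COO serves edge lookups such as `find_edges`. A hedged illustration with a toy graph (not from the commit):

```
import dgl
import torch as th

# Restrict a toy graph to the two formats the READMEs request.
g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))).formats(['csc', 'coo'])
sub = dgl.sampling.sample_neighbors(g, th.tensor([3]), 1)  # in-neighbor sampling reads CSC
src, dst = g.find_edges(th.tensor([0]))                    # edge lookup reads COO
```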
@@ -93,11 +93,14 @@ def initialize(ip_config, num_servers=1, num_workers=0,
             'Please define DGL_NUM_CLIENT to run DistGraph server'
         assert os.environ.get('DGL_CONF_PATH') is not None, \
             'Please define DGL_CONF_PATH to run DistGraph server'
+        formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
+        formats = [f.strip() for f in formats]
         serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
                                os.environ.get('DGL_IP_CONFIG'),
                                int(os.environ.get('DGL_NUM_SERVER')),
                                int(os.environ.get('DGL_NUM_CLIENT')),
-                               os.environ.get('DGL_CONF_PATH'))
+                               os.environ.get('DGL_CONF_PATH'),
+                               graph_format=formats)
         serv.start()
         sys.exit()
     else:
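`DGL_GRAPH_FORMAT` is a comma-separated list that falls back to `csc`, and the `strip()` pass makes whitespace around the commas harmless. A quick sketch of the parsing behavior, taken directly from the hunk above:

```
import os

os.environ['DGL_GRAPH_FORMAT'] = 'csc, coo'   # as set by the launch script
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
assert formats == ['csc', 'coo']
```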
@@ -62,8 +62,8 @@ class InitGraphResponse(rpc.Response):
     def __setstate__(self, state):
         self._graph_name = state

-def _copy_graph_to_shared_mem(g, graph_name):
-    new_g = g.shared_memory(graph_name, formats='csc')
+def _copy_graph_to_shared_mem(g, graph_name, graph_format):
+    new_g = g.shared_memory(graph_name, formats=graph_format)
     # We should share the node/edge data with the client explicitly instead of putting them
     # in the KVStore, because some of the node/edge data may be duplicated.
     new_g.ndata['inner_node'] = _to_shared_mem(g.ndata['inner_node'],
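`_copy_graph_to_shared_mem` now forwards the requested formats instead of hard-coding `'csc'`, so clients that attach to the shared-memory graph see exactly the formats the server created. A hedged sketch of the underlying call (the graph and the name are illustrative):

```
import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
# Put the structure in shared memory under a name, restricted to the given formats.
shared_g = g.shared_memory('demo_graph', formats=['csc', 'coo'])
```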
@@ -291,9 +291,12 @@ class DistGraphServer(KVServer):
         The path of the config file generated by the partition tool.
     disable_shared_mem : bool
         Disable shared memory.
+    graph_format : str or list of str
+        The graph format(s) to create for the partition served by this server.
     '''
     def __init__(self, server_id, ip_config, num_servers,
-                 num_clients, part_config, disable_shared_mem=False):
+                 num_clients, part_config, disable_shared_mem=False,
+                 graph_format='csc'):
         super(DistGraphServer, self).__init__(server_id=server_id,
                                               ip_config=ip_config,
                                               num_servers=num_servers,
@@ -309,8 +312,11 @@ class DistGraphServer(KVServer):
             self.client_g, node_feats, edge_feats, self.gpb, graph_name, \
                 ntypes, etypes = load_partition(part_config, self.part_id)
             print('load ' + graph_name)
+            # Create the graph formats specified by the users.
+            self.client_g = self.client_g.formats(graph_format)
+            self.client_g.create_formats_()
             if not disable_shared_mem:
-                self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
+                self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)

             if not disable_shared_mem:
                 self.gpb.shared_memory(graph_name)
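With the new keyword, a server can be told at construction time which formats to materialize for its partition. A hedged usage sketch with illustrative file names:

```
from dgl.distributed import DistGraphServer

serv = DistGraphServer(server_id=0,
                       ip_config='ip_config.txt',           # illustrative paths
                       num_servers=1,
                       num_clients=1,
                       part_config='data/ogb-product.json',
                       graph_format=['csc', 'coo'])         # new in this commit
serv.start()
```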
@@ -16,9 +16,10 @@ from scipy import sparse as spsp
 from dgl.distributed import DistGraphServer, DistGraph

-def start_server(rank, tmpdir, disable_shared_mem, graph_name):
+def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format='csc'):
     g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
-                        tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
+                        tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
+                        graph_format=graph_format)
     g.start()
@@ -119,7 +120,8 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
     pserver_list = []
     ctx = mp.get_context('spawn')
     for i in range(num_server):
-        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_find_edges'))
+        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1,
+                                                   'test_find_edges', ['csr', 'coo']))
         p.start()
         time.sleep(1)
         pserver_list.append(p)
@@ -167,12 +167,14 @@ def submit_jobs(args, udf_command):
     server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
     server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
     server_cmd = server_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
+    server_cmd = server_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
     for i in range(len(hosts)*server_count_per_machine):
         ip, _ = hosts[int(i / server_count_per_machine)]
         cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
         cmd = cmd + ' ' + udf_command
         cmd = 'cd ' + str(args.workspace) + '; ' + cmd
         execute_remote(cmd, ip, args.ssh_port, thread_list)

     # launch client tasks
     client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_SAMPLER=' + str(args.num_samplers)
     client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
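The launcher now threads `--graph_format` through to every server process via the `DGL_GRAPH_FORMAT` environment variable. A hedged sketch of the string it assembles (the `Args` stand-in and its values are illustrative):

```
class Args:                       # stand-in for the parsed launcher arguments
    part_config = 'data/ogb-product.json'
    ip_config = 'ip_config.txt'
    num_servers = 1
    graph_format = 'csc,coo'

args = Args()
server_cmd = 'DGL_CONF_PATH=' + str(args.part_config)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
server_cmd = server_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
server_cmd = server_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
print(server_cmd)
# DGL_CONF_PATH=data/ogb-product.json DGL_IP_CONFIG=ip_config.txt DGL_NUM_SERVER=1 DGL_GRAPH_FORMAT=csc,coo
```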
@@ -185,6 +187,7 @@ def submit_jobs(args, udf_command):
         client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_omp_threads)
     if os.environ.get('PYTHONPATH') is not None:
         client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
+    client_cmd = client_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
     torch_cmd = '-m torch.distributed.launch'
     torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(args.num_trainers)
@@ -248,6 +251,10 @@ def main():
                              help='The number of OMP threads in the server process. \
                                    It should be small if server processes and trainer processes run on \
                                    the same machine. By default, it is 1.')
+    parser.add_argument('--graph_format', type=str, default='csc',
+                        help='The format of the graph structure of each partition. \
+                              The allowed formats are csr, csc and coo. A user can specify multiple \
+                              formats, separated by ",", e.g., "csr,csc" creates both formats.')
     args, udf_command = parser.parse_known_args()
     assert len(udf_command) == 1, 'Please provide user command line.'
     assert args.num_trainers is not None and args.num_trainers > 0, \
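The option arrives as a plain string; it is only split into a list on the server side (see the `initialize` hunk earlier). A quick check of the parsing:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--graph_format', type=str, default='csc')
args = parser.parse_args(['--graph_format', 'csr,coo'])
print(args.graph_format)          # 'csr,coo' (split on commas by the server)
```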