Unverified commit 86229d42, authored by Da Zheng and committed by GitHub
Browse files

[Distributed] Automatically setting the number of OMP threads for trainers (#2812)

* set omp thread.

* add comment.

* fix.
parent cba5af22
...@@ -8,6 +8,7 @@ import signal ...@@ -8,6 +8,7 @@ import signal
import logging import logging
import time import time
import json import json
import multiprocessing
from threading import Thread from threading import Thread
DEFAULT_PORT = 30050 DEFAULT_PORT = 30050
...@@ -77,6 +78,8 @@ def submit_jobs(args, udf_command): ...@@ -77,6 +78,8 @@ def submit_jobs(args, udf_command):
client_cmd = client_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers) client_cmd = client_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
if os.environ.get('OMP_NUM_THREADS') is not None: if os.environ.get('OMP_NUM_THREADS') is not None:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS') client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS')
else:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_omp_threads)
if os.environ.get('PYTHONPATH') is not None: if os.environ.get('PYTHONPATH') is not None:
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH') client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
...@@ -111,6 +114,8 @@ def main(): ...@@ -111,6 +114,8 @@ def main():
the contents of current directory will be rsyncd') the contents of current directory will be rsyncd')
parser.add_argument('--num_trainers', type=int, parser.add_argument('--num_trainers', type=int,
help='The number of trainer processes per machine') help='The number of trainer processes per machine')
parser.add_argument('--num_omp_threads', type=int,
help='The number of OMP threads per trainer')
parser.add_argument('--num_samplers', type=int, default=0, parser.add_argument('--num_samplers', type=int, default=0,
help='The number of sampler processes per trainer process') help='The number of sampler processes per trainer process')
parser.add_argument('--num_servers', type=int, parser.add_argument('--num_servers', type=int,
...@@ -137,6 +142,12 @@ def main(): ...@@ -137,6 +142,12 @@ def main():
'A user has to specify a partition configuration file with --part_config.' 'A user has to specify a partition configuration file with --part_config.'
assert args.ip_config is not None, \ assert args.ip_config is not None, \
'A user has to specify an IP configuration file with --ip_config.' 'A user has to specify an IP configuration file with --ip_config.'
if args.num_omp_threads is None:
# Here we assume all machines have the same number of CPU cores as the machine
# where the launch script runs.
args.num_omp_threads = max(multiprocessing.cpu_count() // 2 // args.num_trainers, 1)
print('The number of OMP threads per trainer is set to', args.num_omp_threads)
udf_command = str(udf_command[0]) udf_command = str(udf_command[0])
if 'python' not in udf_command: if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.") raise RuntimeError("DGL launching script can only support Python executable file.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment