Unverified Commit b9ef70e5 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] Set the number of threads correctly to speed up (#1976)

* temp fix omp.

* set server threads.

* add CAPI to set up OMP threads.

* fix.

* fix.

* update namespace.

* set cpi properly.

* allow to config num worker threads.

* set #threads.

* fix.
parent f0fbbc16
......@@ -9,6 +9,7 @@ from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore
from .rpc_client import connect_to_server, shutdown_servers
from .role import init_role
from .. import utils
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
......@@ -25,10 +26,11 @@ def get_sampler_pool():
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, max_queue_size, net_type, role):
def _init_rpc(ip_config, max_queue_size, net_type, role, num_threads):
''' This init function is called in the worker processes.
'''
try:
utils.set_num_threads(num_threads)
connect_to_server(ip_config, max_queue_size, net_type)
init_role(role)
init_kvstore(ip_config, role)
......@@ -38,7 +40,8 @@ def _init_rpc(ip_config, max_queue_size, net_type, role):
raise e
def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
num_worker_threads=1):
"""Init rpc service
ip_config: str
File path of ip_config file
......@@ -50,6 +53,8 @@ def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type
it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket'.
num_worker_threads: int
The number of threads in a worker process.
"""
rpc.reset()
ctx = mp.get_context("spawn")
......@@ -58,7 +63,7 @@ def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type
if num_workers > 0:
SAMPLER_POOL = ctx.Pool(
num_workers, initializer=_init_rpc, initargs=(ip_config, max_queue_size,
net_type, 'sampler'))
net_type, 'sampler', num_worker_threads))
NUM_SAMPLER_WORKERS = num_workers
connect_to_server(ip_config, max_queue_size, net_type)
init_role('default')
......
......@@ -9,6 +9,7 @@ import numpy as np
from ..base import DGLError, dgl_warning, NID, EID
from .. import backend as F
from .. import ndarray as nd
from .._ffi.function import _init_api
class InconsistentDtypeException(DGLError):
"""Exception class for inconsistent dtype between graph and tensor"""
......@@ -793,3 +794,15 @@ def extract_subframes(graph, nodes, edges):
subf[EID] = ind_edges
edge_frames.append(subf)
return node_frames, edge_frames
def set_num_threads(num_threads):
    """Set the number of OMP threads in the process.

    Parameters
    ----------
    num_threads : int
        The number of OMP threads in the process. Must be a positive integer.

    Raises
    ------
    ValueError
        If ``num_threads`` is not positive.
    """
    # The value is forwarded to the C API that configures OpenMP. OpenMP only
    # defines the behavior for positive thread counts, so reject anything else
    # here with a clear error instead of handing it to the native runtime.
    if num_threads <= 0:
        raise ValueError(
            'num_threads must be a positive integer, but got {}'.format(num_threads))
    _CAPI_DGLSetOMPThreads(num_threads)
# NOTE(review): presumably binds the C APIs registered under the
# "utils.internal" prefix (e.g. _CAPI_DGLSetOMPThreads) into this module's
# namespace -- confirm against _ffi.function._init_api.
_init_api("dgl.utils.internal")
/*!
* Copyright (c) 2020 by Contributors
* \file utils.cc
* \brief DGL util functions
*/
#include <omp.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {

// Registers a global function named "utils.internal._CAPI_DGLSetOMPThreads"
// whose body forwards its first argument to omp_set_num_threads().
DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // args[0] carries the requested thread count; nothing is written to rv.
    const int requested_threads = args[0];
    omp_set_num_threads(requested_threads);
  });

}  // namespace dgl
......@@ -50,6 +50,7 @@ def submit_jobs(args, udf_command):
tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
# launch server tasks
server_cmd = 'DGL_ROLE=server'
server_cmd = server_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_server_threads)
server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
......@@ -105,6 +106,10 @@ def main():
help='The file (in workspace) of the partition config')
parser.add_argument('--ip_config', type=str,
help='The file (in workspace) of IP configuration for server processes')
parser.add_argument('--num_server_threads', type=int, default=1,
help='The number of OMP threads in the server process. \
It should be small if server processes and trainer processes run on \
the same machine. By default, it is 1.')
args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers > 0, '--num_trainers must be a positive number.'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment