Unverified commit b258729b, authored by Rhett Ying, committed by GitHub

[Dist] set socket as default backend for RPC (#4120)

* [Dist] set socket as default backend for RPC

* add tests for both socket and tensorpipe
parent 702d08db
```diff
@@ -174,7 +174,7 @@ class CustomPool:
 def initialize(ip_config, num_servers=1, num_workers=0,
-               max_queue_size=MAX_QUEUE_SIZE, net_type='tensorpipe',
+               max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
                num_worker_threads=1):
     """Initialize DGL's distributed module
@@ -203,7 +203,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
     net_type : str, optional
         Networking type. Valid options are: ``'socket'``, ``'tensorpipe'``.
-        Default: ``'tensorpipe'``
+        Default: ``'socket'``
     num_worker_threads: int
         The number of threads in a worker process.
```
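Taken together, the hunks flip the default `net_type` from `'tensorpipe'` to `'socket'` in `initialize`, `DistGraphServer`, `connect_to_server`, and `start_server`, and parametrize the RPC tests over both backends. A minimal trainer-side sketch of opting back into tensorpipe, assuming a launched distributed environment and a hypothetical `ip_config.txt`:

```python
import dgl

# net_type now defaults to 'socket'; pass it explicitly to keep the old
# tensorpipe behavior. "ip_config.txt" is a hypothetical file listing one
# "<ip> <port>" pair per server machine.
dgl.distributed.initialize("ip_config.txt", net_type="tensorpipe")
```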
```diff
@@ -317,7 +317,7 @@ class DistGraphServer(KVServer):
     def __init__(self, server_id, ip_config, num_servers,
                  num_clients, part_config, disable_shared_mem=False,
                  graph_format=('csc', 'coo'), keep_alive=False,
-                 net_type='tensorpipe'):
+                 net_type='socket'):
         super(DistGraphServer, self).__init__(server_id=server_id,
                                               ip_config=ip_config,
                                               num_servers=num_servers,
```
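The server class picks up the same default. A hedged sketch of constructing and starting a graph server; the IDs, counts, and paths below are placeholder assumptions that must match the actual cluster and partition layout:

```python
from dgl.distributed import DistGraphServer

# Placeholder values: server_id indexes into ip_config.txt, and part_config
# points at the partition-metadata JSON produced by graph partitioning.
serv = DistGraphServer(
    server_id=0,
    ip_config="ip_config.txt",
    num_servers=1,
    num_clients=1,
    part_config="data/graph.json",
    net_type="socket",  # explicit here, though it is now the default
)
serv.start()  # blocks while serving requests
```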
```diff
@@ -103,7 +103,7 @@ def get_local_usable_addr(probe_addr):
 def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
-                      net_type='tensorpipe', group_id=0):
+                      net_type='socket', group_id=0):
     """Connect this client to server.

     Parameters
```
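On the client side, `connect_to_server` follows suit. A short sketch, assuming the servers listed in a hypothetical `ip_config.txt` are already up; client and servers have to agree on the backend for the RPC channel to come up:

```python
import dgl

# Request the socket backend explicitly (it is now also the default).
dgl.distributed.connect_to_server(
    ip_config="ip_config.txt",
    num_servers=1,
    net_type="socket",
)
```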
```diff
@@ -7,7 +7,7 @@ from . import rpc
 from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE

 def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
-                 max_queue_size=MAX_QUEUE_SIZE, net_type='tensorpipe'):
+                 max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
     """Start DGL server, which will be shared with all the rpc services.

     This is a blocking function -- it returns only when the server shutdown.
```
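`start_server` is the blocking, low-level counterpart that the tests below drive directly. A minimal sketch mirroring the pattern in DGL's RPC tests, with an empty `ServerState` standing in for a real kv-store/graph/partition-book:

```python
import dgl

# A bare server state is enough for a plain RPC server; real deployments
# attach a kv-store, the local graph, and a partition book.
server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)

dgl.distributed.start_server(
    server_id=0,
    ip_config="ip_config.txt",  # hypothetical path
    num_servers=1,
    num_clients=1,
    server_state=server_state,
    net_type="socket",  # now the default
)
```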
```diff
@@ -271,13 +271,14 @@ def test_rpc_msg():
     assert F.array_equal(rpcmsg.tensors[0], req.z)

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
-def test_rpc():
+@pytest.mark.parametrize("net_type", ['tensorpipe'])
+def test_rpc(net_type):
     reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     generate_ip_config("rpc_ip_config.txt", 1, 1)
     ctx = mp.get_context('spawn')
-    pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
-    pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
+    pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt", 0, False, 1, net_type))
+    pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt", 0, 1, net_type))
     pserver.start()
     pclient.start()
     pserver.join()
```
```diff
@@ -306,20 +306,22 @@ def test_multi_client(net_type):

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
-def test_multi_thread_rpc():
+@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
+def test_multi_thread_rpc(net_type):
     reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     num_servers = 2
-    generate_ip_config("rpc_ip_config_multithread.txt", num_servers, num_servers)
+    ip_config = "rpc_ip_config_multithread.txt"
+    generate_ip_config(ip_config, num_servers, num_servers)
     ctx = mp.get_context('spawn')
     pserver_list = []
     for i in range(num_servers):
-        pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config_multithread.txt", i))
+        pserver = ctx.Process(target=start_server, args=(1, ip_config, i, False, 1, net_type))
         pserver.start()
         pserver_list.append(pserver)

     def start_client_multithread(ip_config):
         import threading
-        dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
+        dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1, net_type=net_type)
         dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)

         req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
@@ -341,7 +344,7 @@ def test_multi_thread_rpc():
         assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
         dgl.distributed.exit_client()

-    start_client_multithread("rpc_ip_config_multithread.txt")
+    start_client_multithread(ip_config)
     pserver.join()

 @unittest.skipIf(True, reason="Tests of multiple groups may fail and let's disable them for now.")
```
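The backend choice is threaded through the tests via `pytest.mark.parametrize`, which materializes one test per value, e.g. `test_multi_thread_rpc[socket]` and `test_multi_thread_rpc[tensorpipe]`. A stripped-down illustration of the mechanism:

```python
import pytest

# pytest runs this body once per listed backend; the real tests forward the
# value to start_server / connect_to_server.
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_net_type_is_valid(net_type):
    assert net_type in ("socket", "tensorpipe")
```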