Unverified Commit 0f322331 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistDGL] remove tensorpipe from python APIs (#5848)

parent e47a0279
...@@ -39,7 +39,6 @@ def _init_rpc( ...@@ -39,7 +39,6 @@ def _init_rpc(
ip_config, ip_config,
num_servers, num_servers,
max_queue_size, max_queue_size,
net_type,
role, role,
num_threads, num_threads,
group_id, group_id,
...@@ -48,9 +47,7 @@ def _init_rpc( ...@@ -48,9 +47,7 @@ def _init_rpc(
try: try:
utils.set_num_threads(num_threads) utils.set_num_threads(num_threads)
if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone": if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
connect_to_server( connect_to_server(ip_config, num_servers, max_queue_size, group_id)
ip_config, num_servers, max_queue_size, net_type, group_id
)
init_role(role) init_role(role)
init_kvstore(ip_config, num_servers, role) init_kvstore(ip_config, num_servers, role)
except Exception as e: except Exception as e:
...@@ -211,7 +208,6 @@ class CustomPool: ...@@ -211,7 +208,6 @@ class CustomPool:
def initialize( def initialize(
ip_config, ip_config,
max_queue_size=MAX_QUEUE_SIZE, max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
num_worker_threads=1, num_worker_threads=1,
): ):
"""Initialize DGL's distributed module """Initialize DGL's distributed module
...@@ -230,10 +226,6 @@ def initialize( ...@@ -230,10 +226,6 @@ def initialize(
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
it will not allocate 20GB memory at once. it will not allocate 20GB memory at once.
net_type : str, optional
Networking type. Valid options are: ``'socket'``, ``'tensorpipe'``.
Default: ``'socket'``
num_worker_threads: int num_worker_threads: int
The number of OMP threads in each sampler process. The number of OMP threads in each sampler process.
...@@ -273,7 +265,6 @@ def initialize( ...@@ -273,7 +265,6 @@ def initialize(
os.environ.get("DGL_CONF_PATH"), os.environ.get("DGL_CONF_PATH"),
graph_format=formats, graph_format=formats,
keep_alive=keep_alive, keep_alive=keep_alive,
net_type=net_type,
) )
serv.start() serv.start()
sys.exit() sys.exit()
...@@ -294,7 +285,6 @@ def initialize( ...@@ -294,7 +285,6 @@ def initialize(
ip_config, ip_config,
num_servers, num_servers,
max_queue_size, max_queue_size,
net_type,
"sampler", "sampler",
num_worker_threads, num_worker_threads,
group_id, group_id,
...@@ -311,7 +301,6 @@ def initialize( ...@@ -311,7 +301,6 @@ def initialize(
ip_config, ip_config,
num_servers, num_servers,
max_queue_size, max_queue_size,
net_type,
group_id=group_id, group_id=group_id,
) )
init_role("default") init_role("default")
......
...@@ -332,8 +332,6 @@ class DistGraphServer(KVServer): ...@@ -332,8 +332,6 @@ class DistGraphServer(KVServer):
The graph formats. The graph formats.
keep_alive : bool keep_alive : bool
Whether to keep server alive when clients exit Whether to keep server alive when clients exit
net_type : str
Backend rpc type: ``'socket'`` or ``'tensorpipe'``
""" """
def __init__( def __init__(
...@@ -346,7 +344,6 @@ class DistGraphServer(KVServer): ...@@ -346,7 +344,6 @@ class DistGraphServer(KVServer):
disable_shared_mem=False, disable_shared_mem=False,
graph_format=("csc", "coo"), graph_format=("csc", "coo"),
keep_alive=False, keep_alive=False,
net_type="socket",
): ):
super(DistGraphServer, self).__init__( super(DistGraphServer, self).__init__(
server_id=server_id, server_id=server_id,
...@@ -357,7 +354,6 @@ class DistGraphServer(KVServer): ...@@ -357,7 +354,6 @@ class DistGraphServer(KVServer):
self.ip_config = ip_config self.ip_config = ip_config
self.num_servers = num_servers self.num_servers = num_servers
self.keep_alive = keep_alive self.keep_alive = keep_alive
self.net_type = net_type
# Load graph partition data. # Load graph partition data.
if self.is_backup_server(): if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards. # The backup server doesn't load the graph partition. It'll initialized afterwards.
...@@ -474,7 +470,6 @@ class DistGraphServer(KVServer): ...@@ -474,7 +470,6 @@ class DistGraphServer(KVServer):
num_servers=self.num_servers, num_servers=self.num_servers,
num_clients=self.num_clients, num_clients=self.num_clients,
server_state=server_state, server_state=server_state,
net_type=self.net_type,
) )
......
...@@ -138,32 +138,28 @@ def reset(): ...@@ -138,32 +138,28 @@ def reset():
_CAPI_DGLRPCReset() _CAPI_DGLRPCReset()
def create_sender(max_queue_size, net_type): def create_sender(max_queue_size):
"""Create rpc sender of this process. """Create rpc sender of this process.
Parameters Parameters
---------- ----------
max_queue_size : int max_queue_size : int
Maximal size (bytes) of network queue buffer. Maximal size (bytes) of network queue buffer.
net_type : str
Networking type. Current options are: 'socket', 'tensorpipe'.
""" """
max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0")) max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, max_thread_count) _CAPI_DGLRPCCreateSender(int(max_queue_size), "socket", max_thread_count)
def create_receiver(max_queue_size, net_type): def create_receiver(max_queue_size):
"""Create rpc receiver of this process. """Create rpc receiver of this process.
Parameters Parameters
---------- ----------
max_queue_size : int max_queue_size : int
Maximal size (bytes) of network queue buffer. Maximal size (bytes) of network queue buffer.
net_type : str
Networking type. Current options are: 'socket', 'tensorpipe'.
""" """
max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0")) max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, max_thread_count) _CAPI_DGLRPCCreateReceiver(int(max_queue_size), "socket", max_thread_count)
def finalize_sender(): def finalize_sender():
......
...@@ -113,7 +113,6 @@ def connect_to_server( ...@@ -113,7 +113,6 @@ def connect_to_server(
ip_config, ip_config,
num_servers, num_servers,
max_queue_size=MAX_QUEUE_SIZE, max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
group_id=0, group_id=0,
): ):
"""Connect this client to server. """Connect this client to server.
...@@ -128,8 +127,6 @@ def connect_to_server( ...@@ -128,8 +127,6 @@ def connect_to_server(
Maximal size (bytes) of client queue buffer (~20 GB on default). Maximal size (bytes) of client queue buffer (~20 GB on default).
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
it will not allocate 20GB memory at once. it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket', 'tensorpipe'.
group_id : int group_id : int
Indicates which group this client belongs to. Clients that are Indicates which group this client belongs to. Clients that are
booted together in each launch are gathered as a group and should booted together in each launch are gathered as a group and should
...@@ -145,9 +142,6 @@ def connect_to_server( ...@@ -145,9 +142,6 @@ def connect_to_server(
assert max_queue_size > 0, ( assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size "queue_size (%d) cannot be a negative number." % max_queue_size
) )
assert net_type in ("socket", "tensorpipe"), (
"net_type (%s) can only be 'socket' or 'tensorpipe'." % net_type
)
# Register some basic service # Register some basic service
rpc.register_service( rpc.register_service(
rpc.CLIENT_REGISTER, rpc.CLIENT_REGISTER,
...@@ -181,8 +175,8 @@ def connect_to_server( ...@@ -181,8 +175,8 @@ def connect_to_server(
machine_id = get_local_machine_id(server_namebook) machine_id = get_local_machine_id(server_namebook)
rpc.set_machine_id(machine_id) rpc.set_machine_id(machine_id)
rpc.set_group_id(group_id) rpc.set_group_id(group_id)
rpc.create_sender(max_queue_size, net_type) rpc.create_sender(max_queue_size)
rpc.create_receiver(max_queue_size, net_type) rpc.create_receiver(max_queue_size)
# Get connected with all server nodes # Get connected with all server nodes
max_try_times = int(os.environ.get("DGL_DIST_MAX_TRY_TIMES", 1024)) max_try_times = int(os.environ.get("DGL_DIST_MAX_TRY_TIMES", 1024))
for server_id, addr in server_namebook.items(): for server_id, addr in server_namebook.items():
...@@ -212,9 +206,7 @@ def connect_to_server( ...@@ -212,9 +206,7 @@ def connect_to_server(
for server_id in range(num_servers): for server_id in range(num_servers):
rpc.send_request(server_id, register_req) rpc.send_request(server_id, register_req)
# wait server connect back # wait server connect back
rpc.wait_for_senders( rpc.wait_for_senders(client_ip, client_port, num_servers, blocking=True)
client_ip, client_port, num_servers, blocking=net_type == "socket"
)
print( print(
"Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port) "Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port)
) )
...@@ -263,7 +255,7 @@ def shutdown_servers(ip_config, num_servers): ...@@ -263,7 +255,7 @@ def shutdown_servers(ip_config, num_servers):
rpc.register_sig_handler() rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers) server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook) num_servers = len(server_namebook)
rpc.create_sender(MAX_QUEUE_SIZE, "tensorpipe") rpc.create_sender(MAX_QUEUE_SIZE)
# Get connected with all server nodes # Get connected with all server nodes
for server_id, addr in server_namebook.items(): for server_id, addr in server_namebook.items():
server_ip = addr[1] server_ip = addr[1]
......
...@@ -15,7 +15,6 @@ def start_server( ...@@ -15,7 +15,6 @@ def start_server(
num_clients, num_clients,
server_state, server_state,
max_queue_size=MAX_QUEUE_SIZE, max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
): ):
"""Start DGL server, which will be shared with all the rpc services. """Start DGL server, which will be shared with all the rpc services.
...@@ -40,8 +39,6 @@ def start_server( ...@@ -40,8 +39,6 @@ def start_server(
Maximal size (bytes) of server queue buffer (~20 GB on default). Maximal size (bytes) of server queue buffer (~20 GB on default).
Note that the 20 GB is just an upper-bound because DGL uses zero-copy and Note that the 20 GB is just an upper-bound because DGL uses zero-copy and
it will not allocate 20GB memory at once. it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: ``'socket'`` or ``'tensorpipe'``.
""" """
assert server_id >= 0, ( assert server_id >= 0, (
"server_id (%d) cannot be a negative number." % server_id "server_id (%d) cannot be a negative number." % server_id
...@@ -55,18 +52,8 @@ def start_server( ...@@ -55,18 +52,8 @@ def start_server(
assert max_queue_size > 0, ( assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size "queue_size (%d) cannot be a negative number." % max_queue_size
) )
assert net_type in ("socket", "tensorpipe"), (
"net_type (%s) can only be 'socket' or 'tensorpipe'" % net_type
)
if server_state.keep_alive: if server_state.keep_alive:
assert ( assert False, "Long live server is not supported any more."
net_type == "tensorpipe"
), "net_type can only be 'tensorpipe' if 'keep_alive' is enabled."
print(
"As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received."
" [WARNING] This feature is experimental and not fully tested."
)
# Register signal handler. # Register signal handler.
rpc.register_sig_handler() rpc.register_sig_handler()
# Register some basic services # Register some basic services
...@@ -90,17 +77,15 @@ def start_server( ...@@ -90,17 +77,15 @@ def start_server(
rpc.set_machine_id(machine_id) rpc.set_machine_id(machine_id)
ip_addr = server_namebook[server_id][1] ip_addr = server_namebook[server_id][1]
port = server_namebook[server_id][2] port = server_namebook[server_id][2]
rpc.create_sender(max_queue_size, net_type) rpc.create_sender(max_queue_size)
rpc.create_receiver(max_queue_size, net_type) rpc.create_receiver(max_queue_size)
# wait all the senders connect to server. # wait all the senders connect to server.
# Once all the senders connect to server, server will not # Once all the senders connect to server, server will not
# accept new sender's connection # accept new sender's connection
print( print(
"Server is waiting for connections on [{}:{}]...".format(ip_addr, port) "Server is waiting for connections on [{}:{}]...".format(ip_addr, port)
) )
rpc.wait_for_senders( rpc.wait_for_senders(ip_addr, port, num_clients, blocking=True)
ip_addr, port, num_clients, blocking=net_type == "socket"
)
rpc.set_num_client(num_clients) rpc.set_num_client(num_clients)
recv_clients = {} recv_clients = {}
while True: while True:
......
...@@ -49,9 +49,7 @@ class HelloRequest(dgl.distributed.Request): ...@@ -49,9 +49,7 @@ class HelloRequest(dgl.distributed.Request):
return res return res
def start_server( def start_server(server_id, ip_config, num_servers, num_clients, keep_alive):
server_id, ip_config, num_servers, num_clients, net_type, keep_alive
):
server_state = dgl.distributed.ServerState( server_state = dgl.distributed.ServerState(
None, local_g=None, partition_book=None, keep_alive=keep_alive None, local_g=None, partition_book=None, keep_alive=keep_alive
) )
...@@ -65,11 +63,10 @@ def start_server( ...@@ -65,11 +63,10 @@ def start_server(
num_servers=num_servers, num_servers=num_servers,
num_clients=num_clients, num_clients=num_clients,
server_state=server_state, server_state=server_state,
net_type=net_type,
) )
def start_client(ip_config, num_servers, group_id, net_type): def start_client(ip_config, num_servers, group_id):
dgl.distributed.register_service( dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse HELLO_SERVICE_ID, HelloRequest, HelloResponse
) )
...@@ -77,7 +74,6 @@ def start_client(ip_config, num_servers, group_id, net_type): ...@@ -77,7 +74,6 @@ def start_client(ip_config, num_servers, group_id, net_type):
ip_config=ip_config, ip_config=ip_config,
num_servers=num_servers, num_servers=num_servers,
group_id=group_id, group_id=group_id,
net_type=net_type,
) )
req = HelloRequest(STR, INTEGER, TENSOR, tensor_func) req = HelloRequest(STR, INTEGER, TENSOR, tensor_func)
server_namebook = dgl.distributed.read_ip_config(ip_config, num_servers) server_namebook = dgl.distributed.read_ip_config(ip_config, num_servers)
...@@ -117,17 +113,14 @@ def start_client(ip_config, num_servers, group_id, net_type): ...@@ -117,17 +113,14 @@ def start_client(ip_config, num_servers, group_id, net_type):
def main(): def main():
ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG") ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG")
num_servers = int(os.environ.get("DIST_DGL_TEST_NUM_SERVERS")) num_servers = int(os.environ.get("DIST_DGL_TEST_NUM_SERVERS"))
net_type = os.environ.get("DIST_DGL_TEST_NET_TYPE", "tensorpipe")
if os.environ.get("DIST_DGL_TEST_ROLE", "server") == "server": if os.environ.get("DIST_DGL_TEST_ROLE", "server") == "server":
server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID")) server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
num_clients = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENTS")) num_clients = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENTS"))
keep_alive = "DIST_DGL_TEST_KEEP_ALIVE" in os.environ keep_alive = "DIST_DGL_TEST_KEEP_ALIVE" in os.environ
start_server( start_server(server_id, ip_config, num_servers, num_clients, keep_alive)
server_id, ip_config, num_servers, num_clients, net_type, keep_alive
)
else: else:
group_id = int(os.environ.get("DIST_DGL_TEST_GROUP_ID", "0")) group_id = int(os.environ.get("DIST_DGL_TEST_GROUP_ID", "0"))
start_client(ip_config, num_servers, group_id, net_type) start_client(ip_config, num_servers, group_id)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -16,7 +16,6 @@ num_client_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENT")) ...@@ -16,7 +16,6 @@ num_client_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENT"))
shared_workspace = os.environ.get("DIST_DGL_TEST_WORKSPACE") shared_workspace = os.environ.get("DIST_DGL_TEST_WORKSPACE")
graph_path = os.environ.get("DIST_DGL_TEST_GRAPH_PATH") graph_path = os.environ.get("DIST_DGL_TEST_GRAPH_PATH")
part_id = int(os.environ.get("DIST_DGL_TEST_PART_ID")) part_id = int(os.environ.get("DIST_DGL_TEST_PART_ID"))
net_type = os.environ.get("DIST_DGL_TEST_NET_TYPE")
ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt") ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
...@@ -57,7 +56,6 @@ def run_server( ...@@ -57,7 +56,6 @@ def run_server(
disable_shared_mem=not shared_mem, disable_shared_mem=not shared_mem,
graph_format=["csc", "coo"], graph_format=["csc", "coo"],
keep_alive=keep_alive, keep_alive=keep_alive,
net_type=net_type,
) )
print("start server", server_id) print("start server", server_id)
g.start() g.start()
...@@ -780,7 +778,7 @@ if mode == "server": ...@@ -780,7 +778,7 @@ if mode == "server":
) )
elif mode == "client": elif mode == "client":
os.environ["DGL_NUM_SERVER"] = str(num_servers_per_machine) os.environ["DGL_NUM_SERVER"] = str(num_servers_per_machine)
dgl.distributed.initialize(ip_config, net_type=net_type) dgl.distributed.initialize(ip_config)
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
graph_path + "/{}.json".format(graph_name), part_id graph_path + "/{}.json".format(graph_name), part_id
......
import os
import unittest
from utils import execute_remote, get_ips
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_tensorpipe_comm():
    """End-to-end smoke test of the C++ RPC binaries over tensorpipe.

    Launches one ``rpc_server`` and one ``rpc_client`` process on every
    machine listed in the IP config file, waits for all of them to
    finish, and asserts that each process exited cleanly.
    """
    bin_dir = os.environ.get("DIST_DGL_TEST_CPP_BIN_DIR", ".")
    ip_cfg = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
    client_cmd = os.path.join(bin_dir, "rpc_client")
    server_cmd = os.path.join(bin_dir, "rpc_server")
    machines = get_ips(ip_cfg)
    machine_count = len(machines)
    # Start the servers first; each needs the total machine count and its own IP.
    launched = [
        execute_remote(server_cmd + " " + str(machine_count) + " " + ip, ip)
        for ip in machines
    ]
    # Then start one client per machine, pointed at the shared IP config.
    launched += [
        execute_remote(client_cmd + " " + ip_cfg, ip) for ip in machines
    ]
    for proc in launched:
        proc.join()
        assert proc.exitcode == 0
...@@ -96,12 +96,11 @@ def create_graph(num_part, dist_graph_path, hetero): ...@@ -96,12 +96,11 @@ def create_graph(num_part, dist_graph_path, hetero):
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["tensorpipe", "socket"])
@pytest.mark.parametrize("num_servers", [1, 4]) @pytest.mark.parametrize("num_servers", [1, 4])
@pytest.mark.parametrize("num_clients", [1, 4]) @pytest.mark.parametrize("num_clients", [1, 4])
@pytest.mark.parametrize("hetero", [False, True]) @pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("shared_mem", [False, True]) @pytest.mark.parametrize("shared_mem", [False, True])
def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem): def test_dist_objects(num_servers, num_clients, hetero, shared_mem):
if not shared_mem and num_servers > 1: if not shared_mem and num_servers > 1:
pytest.skip( pytest.skip(
f"Backup servers are not supported when shared memory is disabled" f"Backup servers are not supported when shared memory is disabled"
...@@ -126,7 +125,6 @@ def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem): ...@@ -126,7 +125,6 @@ def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem):
f"DIST_DGL_TEST_NUM_PART={num_part} " f"DIST_DGL_TEST_NUM_PART={num_part} "
f"DIST_DGL_TEST_NUM_SERVER={num_servers} " f"DIST_DGL_TEST_NUM_SERVER={num_servers} "
f"DIST_DGL_TEST_NUM_CLIENT={num_clients} " f"DIST_DGL_TEST_NUM_CLIENT={num_clients} "
f"DIST_DGL_TEST_NET_TYPE={net_type} "
f"DIST_DGL_TEST_GRAPH_PATH={dist_graph_path} " f"DIST_DGL_TEST_GRAPH_PATH={dist_graph_path} "
f"DIST_DGL_TEST_IP_CONFIG={ip_config} " f"DIST_DGL_TEST_IP_CONFIG={ip_config} "
) )
......
...@@ -9,8 +9,7 @@ dgl_envs = f"PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 DGLBACKEND={os.environ.get('DGL ...@@ -9,8 +9,7 @@ dgl_envs = f"PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 DGLBACKEND={os.environ.get('DGL
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"]) def test_rpc():
def test_rpc(net_type):
ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt") ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
num_clients = 1 num_clients = 1
num_servers = 1 num_servers = 1
...@@ -21,7 +20,7 @@ def test_rpc(net_type): ...@@ -21,7 +20,7 @@ def test_rpc(net_type):
) )
base_envs = ( base_envs = (
dgl_envs dgl_envs
+ f" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={num_servers} DIST_DGL_TEST_NET_TYPE={net_type} " + f" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={num_servers} "
) )
procs = [] procs = []
# start server processes # start server processes
......
...@@ -135,7 +135,6 @@ def start_server( ...@@ -135,7 +135,6 @@ def start_server(
server_id=0, server_id=0,
keep_alive=False, keep_alive=False,
num_servers=1, num_servers=1,
net_type="tensorpipe",
): ):
print("Sleep 1 seconds to test client re-connect.") print("Sleep 1 seconds to test client re-connect.")
time.sleep(1) time.sleep(1)
...@@ -155,11 +154,10 @@ def start_server( ...@@ -155,11 +154,10 @@ def start_server(
num_servers=num_servers, num_servers=num_servers,
num_clients=num_clients, num_clients=num_clients,
server_state=server_state, server_state=server_state,
net_type=net_type,
) )
def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"): def start_client(ip_config, group_id=0, num_servers=1):
dgl.distributed.register_service( dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse HELLO_SERVICE_ID, HelloRequest, HelloResponse
) )
...@@ -167,7 +165,6 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"): ...@@ -167,7 +165,6 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"):
ip_config=ip_config, ip_config=ip_config,
num_servers=num_servers, num_servers=num_servers,
group_id=group_id, group_id=group_id,
net_type=net_type,
) )
req = HelloRequest(STR, INTEGER, TENSOR, simple_func) req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv # test send and recv
...@@ -202,9 +199,7 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"): ...@@ -202,9 +199,7 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"):
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
def start_client_timeout( def start_client_timeout(ip_config, group_id=0, num_servers=1):
ip_config, group_id=0, num_servers=1, net_type="tensorpipe"
):
dgl.distributed.register_service( dgl.distributed.register_service(
TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
) )
...@@ -212,7 +207,6 @@ def start_client_timeout( ...@@ -212,7 +207,6 @@ def start_client_timeout(
ip_config=ip_config, ip_config=ip_config,
num_servers=num_servers, num_servers=num_servers,
group_id=group_id, group_id=group_id,
net_type=net_type,
) )
timeout = 1 * 1000 # milliseconds timeout = 1 * 1000 # milliseconds
req = TimeoutRequest(TIMEOUT_META, timeout) req = TimeoutRequest(TIMEOUT_META, timeout)
...@@ -258,19 +252,14 @@ def start_client_timeout( ...@@ -258,19 +252,14 @@ def start_client_timeout(
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"]) def test_rpc_timeout():
def test_rpc_timeout(net_type):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
ip_config = "rpc_ip_config.txt" ip_config = "rpc_ip_config.txt"
generate_ip_config(ip_config, 1, 1) generate_ip_config(ip_config, 1, 1)
ctx = mp.get_context("spawn") ctx = mp.get_context("spawn")
pserver = ctx.Process( pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1))
target=start_server, args=(1, ip_config, 0, False, 1, net_type) pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))
)
pclient = ctx.Process(
target=start_client_timeout, args=(ip_config, 0, 1, net_type)
)
pserver.start() pserver.start()
pclient.start() pclient.start()
pserver.join() pserver.join()
...@@ -325,28 +314,7 @@ def test_rpc_msg(): ...@@ -325,28 +314,7 @@ def test_rpc_msg():
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["tensorpipe"]) def test_multi_client():
def test_rpc(net_type):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
generate_ip_config("rpc_ip_config.txt", 1, 1)
ctx = mp.get_context("spawn")
pserver = ctx.Process(
target=start_server,
args=(1, "rpc_ip_config.txt", 0, False, 1, net_type),
)
pclient = ctx.Process(
target=start_client, args=("rpc_ip_config.txt", 0, 1, net_type)
)
pserver.start()
pclient.start()
pserver.join()
pclient.join()
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_client(net_type):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
ip_config = "rpc_ip_config_mul_client.txt" ip_config = "rpc_ip_config_mul_client.txt"
...@@ -355,13 +323,11 @@ def test_multi_client(net_type): ...@@ -355,13 +323,11 @@ def test_multi_client(net_type):
num_clients = 20 num_clients = 20
pserver = ctx.Process( pserver = ctx.Process(
target=start_server, target=start_server,
args=(num_clients, ip_config, 0, False, 1, net_type), args=(num_clients, ip_config, 0, False, 1),
) )
pclient_list = [] pclient_list = []
for i in range(num_clients): for i in range(num_clients):
pclient = ctx.Process( pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))
target=start_client, args=(ip_config, 0, 1, net_type)
)
pclient_list.append(pclient) pclient_list.append(pclient)
pserver.start() pserver.start()
for i in range(num_clients): for i in range(num_clients):
...@@ -372,8 +338,7 @@ def test_multi_client(net_type): ...@@ -372,8 +338,7 @@ def test_multi_client(net_type):
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"]) def test_multi_thread_rpc():
def test_multi_thread_rpc(net_type):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
num_servers = 2 num_servers = 2
...@@ -383,7 +348,7 @@ def test_multi_thread_rpc(net_type): ...@@ -383,7 +348,7 @@ def test_multi_thread_rpc(net_type):
pserver_list = [] pserver_list = []
for i in range(num_servers): for i in range(num_servers):
pserver = ctx.Process( pserver = ctx.Process(
target=start_server, args=(1, ip_config, i, False, 1, net_type) target=start_server, args=(1, ip_config, i, False, 1)
) )
pserver.start() pserver.start()
pserver_list.append(pserver) pserver_list.append(pserver)
...@@ -392,7 +357,8 @@ def test_multi_thread_rpc(net_type): ...@@ -392,7 +357,8 @@ def test_multi_thread_rpc(net_type):
import threading import threading
dgl.distributed.connect_to_server( dgl.distributed.connect_to_server(
ip_config=ip_config, num_servers=1, net_type=net_type ip_config=ip_config,
num_servers=1,
) )
dgl.distributed.register_service( dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse HELLO_SERVICE_ID, HelloRequest, HelloResponse
...@@ -464,8 +430,7 @@ def test_multi_client_groups(): ...@@ -464,8 +430,7 @@ def test_multi_client_groups():
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"]) def test_multi_client_connect():
def test_multi_client_connect(net_type):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
ip_config = "rpc_ip_config_mul_client.txt" ip_config = "rpc_ip_config_mul_client.txt"
...@@ -474,14 +439,14 @@ def test_multi_client_connect(net_type): ...@@ -474,14 +439,14 @@ def test_multi_client_connect(net_type):
num_clients = 1 num_clients = 1
pserver = ctx.Process( pserver = ctx.Process(
target=start_server, target=start_server,
args=(num_clients, ip_config, 0, False, 1, net_type), args=(num_clients, ip_config, 0, False, 1),
) )
# small max try times # small max try times
os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1" os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1"
expect_except = False expect_except = False
try: try:
start_client(ip_config, 0, 1, net_type) start_client(ip_config, 0, 1)
except dgl.distributed.DistConnectError as err: except dgl.distributed.DistConnectError as err:
print("Expected error: {}".format(err)) print("Expected error: {}".format(err))
expect_except = True expect_except = True
...@@ -489,7 +454,7 @@ def test_multi_client_connect(net_type): ...@@ -489,7 +454,7 @@ def test_multi_client_connect(net_type):
# large max try times # large max try times
os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1024" os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1024"
pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type)) pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))
pclient.start() pclient.start()
pserver.start() pserver.start()
pclient.join() pclient.join()
...@@ -500,9 +465,7 @@ def test_multi_client_connect(net_type): ...@@ -500,9 +465,7 @@ def test_multi_client_connect(net_type):
if __name__ == "__main__": if __name__ == "__main__":
test_serialize() test_serialize()
test_rpc_msg() test_rpc_msg()
test_rpc()
test_multi_client("socket") test_multi_client("socket")
test_multi_client("tensorpipe") test_multi_client("tensorpipe")
test_multi_thread_rpc() test_multi_thread_rpc()
test_multi_client_connect("socket") test_multi_client_connect("socket")
test_multi_client_connect("tensorpipe")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment