Unverified Commit 5c92f6c2 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[RPC] Rpc exit with explicit invocation (#1825)

* exit client

* update

* update

* update

* update

* update

* update

* update

* update test

* update

* update

* update

* update

* update

* update

* update

* update

* update
parent 64557819
......@@ -154,8 +154,6 @@ def run(args, device, data):
# clean up
if not args.standalone:
g._client.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
def main(args):
if not args.standalone:
......
......@@ -9,7 +9,7 @@ from .sparse_emb import SparseAdagrad, SparseNodeEmbedding
from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .rpc_client import connect_to_server, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .graph_services import sample_neighbors, in_subgraph
......
......@@ -407,7 +407,6 @@ class DistGraph:
self._num_nodes += int(part_md['num_nodes'])
self._num_edges += int(part_md['num_edges'])
def init_ndata(self, name, shape, dtype, init_func=None):
'''Initialize node data
......
......@@ -2,6 +2,7 @@
import os
import socket
import atexit
from . import rpc
from .constants import MAX_QUEUE_SIZE
......@@ -169,6 +170,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.send_request(0, get_client_num_req)
res = rpc.recv_response()
rpc.set_num_client(res.num_client)
atexit.register(exit_client)
def finalize_client():
"""Release resources of this client."""
......@@ -186,3 +188,11 @@ def shutdown_servers():
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)
def exit_client():
    """Shut down the servers and release this client's resources.

    Sends the shutdown request to the servers via ``shutdown_servers()``
    (which presumably restricts the send to the rank-0 client — confirm
    against its implementation), then releases local client resources
    with ``finalize_client()``.

    Finally, the function removes itself from the ``atexit`` hooks so
    that the cleanup does not run a second time at interpreter exit
    when a caller has already invoked it explicitly.
    """
    # Only the rank-0 client is expected to send the shutdown request
    # to the servers; the filtering happens inside shutdown_servers().
    shutdown_servers()
    finalize_client()
    # Registered in connect_to_server(); unregister to avoid double cleanup.
    atexit.unregister(exit_client)
......@@ -164,10 +164,6 @@ def check_dist_graph(g, num_nodes, num_edges):
for n in nodes:
assert n in local_nids
# clean up
if os.environ['DGL_DIST_MODE'] == 'distributed':
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
print('end')
def check_server_client(shared_mem):
......
......@@ -27,8 +27,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
dgl.distributed.exit_client()
return sampled_graph
......@@ -162,8 +161,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
dgl.distributed.exit_client()
return sampled_graph
......
......@@ -265,8 +265,6 @@ def start_client(num_clients):
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# clean up
kvclient.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store():
......
......@@ -154,10 +154,6 @@ def start_client(ip_config):
# clean up
time.sleep(2)
if dgl.distributed.get_rank() == 0:
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
print("Get rank: %d" % dgl.distributed.get_rank())
def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment