"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9922f41f06e53aa5eb75a68bea615ed59d2e0b03"
Unverified commit 5c92f6c2, authored by Chao Ma, committed by GitHub

[RPC] Rpc exit with explicit invocation (#1825)

* exit client
* update
* update
* update
* update
* update
* update
* update
* update test
* update
* update
* update
* update
* update
* update
* update
* update
* update
parent 64557819
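
The net effect of the diff below: the manual two-step teardown, dgl.distributed.shutdown_servers() followed by dgl.distributed.finalize_client(), is folded into a single dgl.distributed.exit_client(), and connect_to_server() now registers that function as an atexit hook, so cleanup also runs when a script never calls it. A minimal sketch of the new call pattern (the 'ip_config.txt' path is a placeholder, and actually running this needs live DGL servers):

import dgl

# Connecting now registers exit_client() with atexit, so teardown
# happens automatically at interpreter exit ...
dgl.distributed.connect_to_server('ip_config.txt')

# ... distributed training / sampling work goes here ...

# ... or teardown can be invoked explicitly; exit_client() unregisters
# itself, so the hook will not run the teardown a second time.
dgl.distributed.exit_client()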
@@ -154,8 +154,6 @@ def run(args, device, data):
     # clean up
     if not args.standalone:
         g._client.barrier()
-        dgl.distributed.shutdown_servers()
-        dgl.distributed.finalize_client()
 
 def main(args):
     if not args.standalone:
...
@@ -9,7 +9,7 @@ from .sparse_emb import SparseAdagrad, SparseNodeEmbedding
 from .rpc import *
 from .rpc_server import start_server
-from .rpc_client import connect_to_server, finalize_client, shutdown_servers
+from .rpc_client import connect_to_server, exit_client
 from .kvstore import KVServer, KVClient
 from .server_state import ServerState
 from .graph_services import sample_neighbors, in_subgraph
...
@@ -407,7 +407,6 @@ class DistGraph:
         self._num_nodes += int(part_md['num_nodes'])
         self._num_edges += int(part_md['num_edges'])
-
     def init_ndata(self, name, shape, dtype, init_func=None):
         '''Initialize node data
...
@@ -2,6 +2,7 @@
 import os
 import socket
+import atexit
 
 from . import rpc
 from .constants import MAX_QUEUE_SIZE
@@ -169,6 +170,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
     rpc.send_request(0, get_client_num_req)
     res = rpc.recv_response()
     rpc.set_num_client(res.num_client)
+    atexit.register(exit_client)
 
 def finalize_client():
     """Release resources of this client."""
@@ -186,3 +188,11 @@ def shutdown_servers():
     req = rpc.ShutDownRequest(rpc.get_rank())
     for server_id in range(rpc.get_num_server()):
         rpc.send_request(server_id, req)
+
+def exit_client():
+    """Register exit callback.
+    """
+    # Only client with rank_0 will send shutdown request to servers.
+    shutdown_servers()
+    finalize_client()
+    atexit.unregister(exit_client)
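
The atexit.unregister(exit_client) at the end is what lets explicit and implicit invocation coexist: once cleanup has run, the hook is removed so interpreter exit cannot send the shutdown requests a second time. A self-contained sketch of this register-then-unregister pattern (the names are illustrative stand-ins, not DGL code):

import atexit

def _release_resources():
    # Stand-in for shutdown_servers() + finalize_client().
    print('resources released')

def cleanup():
    """Run teardown once, whether called explicitly or via atexit."""
    _release_resources()
    # An explicit call happened; drop the hook so normal interpreter
    # exit does not repeat the teardown.
    atexit.unregister(cleanup)

atexit.register(cleanup)
cleanup()  # prints once here; nothing fires again at interpreter exit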
@@ -164,10 +164,6 @@ def check_dist_graph(g, num_nodes, num_edges):
     for n in nodes:
         assert n in local_nids
 
-    # clean up
-    if os.environ['DGL_DIST_MODE'] == 'distributed':
-        dgl.distributed.shutdown_servers()
-        dgl.distributed.finalize_client()
     print('end')
 
 def check_server_client(shared_mem):
...
@@ -27,8 +27,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
     _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
     dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
-    dgl.distributed.shutdown_servers()
-    dgl.distributed.finalize_client()
+    dgl.distributed.exit_client()
     return sampled_graph
@@ -162,8 +161,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
     _, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
     dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
     sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
-    dgl.distributed.shutdown_servers()
-    dgl.distributed.finalize_client()
+    dgl.distributed.exit_client()
    return sampled_graph
...
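
Why these test clients keep one explicit call instead of relying on the new atexit hook: hooks are skipped when a process terminates via os._exit(), which is how fork-started multiprocessing workers exit, so if these client functions run in such workers (the test harness is not shown in this diff), only the explicit exit_client() reliably sends the shutdown request. A self-contained illustration of that caveat, with no DGL involved (POSIX-only because of os.fork()):

import atexit
import os

def cleanup():
    # Stand-in for exit_client(); prints so the effect is visible.
    print('cleanup ran in pid %d' % os.getpid())

atexit.register(cleanup)

pid = os.fork()
if pid == 0:
    os._exit(0)          # child: os._exit() skips atexit hooks
os.waitpid(pid, 0)
# parent: cleanup() runs exactly once, at normal interpreter exit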
@@ -265,8 +265,6 @@ def start_client(num_clients):
     assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
     # clean up
     kvclient.barrier()
-    dgl.distributed.shutdown_servers()
-    dgl.distributed.finalize_client()
 
 @unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
 def test_kv_store():
...
@@ -154,10 +154,6 @@ def start_client(ip_config):
     # clean up
     time.sleep(2)
-    if dgl.distributed.get_rank() == 0:
-        dgl.distributed.shutdown_servers()
-    dgl.distributed.finalize_client()
-    print("Get rank: %d" % dgl.distributed.get_rank())
 
 def test_serialize():
     from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
...