"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "901da9dccc6de15f7f4fafbacddc1d3533114f8d"
Unverified commit 2de80dde, authored by Rhett Ying, committed by GitHub
Browse files

[DistTest] add python test of RPC (#4093)

* [DistTest] add python test of RPC

* remove return
parent c1ff4c9b
import os
import dgl
import backend as F
from numpy.testing import assert_array_equal
# Integer payload carried by every HelloRequest and checked on both ends.
INTEGER = 2
# String payload carried by every HelloRequest and checked on both ends.
STR = 'hello world!'
# Service id under which HelloRequest/HelloResponse are registered.
HELLO_SERVICE_ID = 901231
# 1000 x 1000 int64 all-zero tensor on CPU, sent as the request's tensor payload.
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())
def tensor_func(tensor):
    """Return ``tensor`` multiplied by two.

    Shipped inside a HelloRequest and applied to the request's tensor by
    ``HelloRequest.process_request``.
    """
    doubled = tensor * 2
    return doubled
class HelloResponse(dgl.distributed.Response):
    """RPC response carrying a string, an integer and a tensor."""

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        """Return the payload as a ``(hello_str, integer, tensor)`` tuple."""
        payload = (self.hello_str, self.integer, self.tensor)
        return payload

    def __setstate__(self, state):
        """Restore the fields from a tuple shaped like ``__getstate__``'s result."""
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
class HelloRequest(dgl.distributed.Request):
    """RPC request carrying a string, an integer, a tensor and a callable."""

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        """Return the payload as a ``(hello_str, integer, tensor, func)`` tuple."""
        payload = (self.hello_str, self.integer, self.tensor, self.func)
        return payload

    def __setstate__(self, state):
        """Restore the fields from a tuple shaped like ``__getstate__``'s result."""
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Validate the payload and build the matching HelloResponse.

        Asserts that the string and integer equal the module constants
        ``STR`` and ``INTEGER``, applies ``self.func`` to ``self.tensor`` and
        returns a HelloResponse holding the string, the integer and the
        transformed tensor. ``server_state`` is not used.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        # Apply the callable before building the response.
        transformed = self.func(self.tensor)
        return HelloResponse(self.hello_str, self.integer, transformed)
def start_server(server_id, ip_config, num_servers, num_clients, net_type, keep_alive):
    """Register the hello service and start a DGL RPC server.

    Builds a ``ServerState`` with no kv-store argument, no local graph and no
    partition book, registers HelloRequest/HelloResponse under
    ``HELLO_SERVICE_ID``, prints a start-up line and then hands control to
    ``dgl.distributed.start_server`` with the given arguments.
    """
    dist = dgl.distributed
    state = dist.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive)
    dist.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    print("Start server {}".format(server_id))
    launch_kwargs = dict(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=state,
        net_type=net_type,
    )
    dist.start_server(**launch_kwargs)
def _check_response(res, expected_tensor):
    """Assert that ``res`` echoes STR/INTEGER and carries ``expected_tensor``."""
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), expected_tensor)


def start_client(ip_config, num_servers, group_id, net_type):
    """Connect to the servers and exercise every RPC entry point.

    For each server listed in ``ip_config`` this sends a HelloRequest through
    ``send_request``/``recv_response``, ``remote_call``,
    ``send_request_to_machine``/``recv_response`` and
    ``remote_call_to_machine``, and checks every response.

    Parameters
    ----------
    ip_config : path of the IP configuration file.
    num_servers : passed through to ``connect_to_server`` and ``read_ip_config``.
    group_id : client group id passed to ``connect_to_server``.
    net_type : network backend name passed to ``connect_to_server``.

    Raises
    ------
    AssertionError
        If any response does not match the request payload.
    """
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id, net_type=net_type)
    req = HelloRequest(STR, INTEGER, TENSOR, tensor_func)
    # process_request applies the request's func to the tensor, so the
    # expected reply is tensor_func(TENSOR), not TENSOR itself. Comparing
    # against raw TENSOR only passes because TENSOR is all zeros.
    expected = F.asnumpy(tensor_func(TENSOR))
    server_namebook = dgl.distributed.read_ip_config(ip_config, num_servers)
    for server_id in server_namebook:
        # The same ten (target, request) pairs are used for both batched calls.
        target_and_requests = [(server_id, req) for _ in range(10)]

        # test send and recv
        dgl.distributed.send_request(server_id, req)
        _check_response(dgl.distributed.recv_response(), expected)

        # test remote_call
        for res in dgl.distributed.remote_call(target_and_requests):
            _check_response(res, expected)

        # test send_request_to_machine
        dgl.distributed.send_request_to_machine(server_id, req)
        _check_response(dgl.distributed.recv_response(), expected)

        # test remote_call_to_machine
        for res in dgl.distributed.remote_call_to_machine(target_and_requests):
            _check_response(res, expected)
def main():
    """Run as server or client according to ``DIST_DGL_TEST_*`` variables.

    ``DIST_DGL_TEST_ROLE`` selects the role and defaults to ``'server'``; any
    other value runs the client. The network type defaults to
    ``'tensorpipe'`` and the client group id to ``0``. The keep-alive flag is
    true whenever ``DIST_DGL_TEST_KEEP_ALIVE`` is present in the environment.
    """
    env = os.environ
    ip_config = env.get('DIST_DGL_TEST_IP_CONFIG')
    num_servers = int(env.get('DIST_DGL_TEST_NUM_SERVERS'))
    net_type = env.get('DIST_DGL_TEST_NET_TYPE', 'tensorpipe')
    role = env.get('DIST_DGL_TEST_ROLE', 'server')

    if role != 'server':
        group_id = int(env.get('DIST_DGL_TEST_GROUP_ID', '0'))
        start_client(ip_config, num_servers, group_id, net_type)
        return

    server_id = int(env.get('DIST_DGL_TEST_SERVER_ID'))
    num_clients = int(env.get('DIST_DGL_TEST_NUM_CLIENTS'))
    keep_alive = 'DIST_DGL_TEST_KEEP_ALIVE' in env
    start_server(server_id, ip_config, num_servers,
                 num_clients, net_type, keep_alive)


if __name__ == '__main__':
    main()
import os
import unittest
import pytest
import multiprocessing as mp
import utils
# Environment prefix prepended to every remotely executed test command.
# Only variables that are actually set locally are forwarded: formatting
# os.environ.get(...) straight into the string would emit the literal text
# "None" (e.g. "DGLBACKEND=None") for an unset variable. When all three are
# set the result is identical to the plain f-string form, including the
# trailing space.
_FORWARDED_ENV_VARS = ('DGLBACKEND', 'DGL_LIBRARY_PATH', 'PYTHONPATH')
dgl_envs = "PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 " + "".join(
    f"{name}={os.environ[name]} "
    for name in _FORWARDED_ENV_VARS
    if name in os.environ
)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_rpc(net_type):
    """Launch RPC servers and clients on every machine and check exit codes.

    Machines come from the file named by ``DIST_DGL_TEST_IP_CONFIG`` (default
    ``ip_config.txt``). One server and one client are started per machine by
    running ``rpc_basic.py`` through ``utils.execute_remote``. Servers are
    started first, then clients, and each process must exit with code 0.
    """
    num_clients = 1
    num_servers = 1
    ip_config = os.environ.get('DIST_DGL_TEST_IP_CONFIG', 'ip_config.txt')
    ips = utils.get_ips(ip_config)
    num_machines = len(ips)
    bin_dir = os.environ.get('DIST_DGL_TEST_PY_BIN_DIR', '.')
    test_bin = os.path.join(bin_dir, 'rpc_basic.py')
    base_envs = dgl_envs + \
        f" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={num_servers} DIST_DGL_TEST_NET_TYPE={net_type} "

    # start server processes; server ids are numbered consecutively across machines
    server_hosts = [ip for ip in ips for _ in range(num_servers)]
    procs = []
    for server_id, ip in enumerate(server_hosts):
        server_envs = base_envs + \
            f" DIST_DGL_TEST_ROLE=server DIST_DGL_TEST_SERVER_ID={server_id} DIST_DGL_TEST_NUM_CLIENTS={num_clients * num_machines} "
        server_cmd = server_envs + " python3 " + test_bin
        procs.append(utils.execute_remote(server_cmd, ip))

    # start client processes
    client_envs = base_envs + " DIST_DGL_TEST_ROLE=client DIST_DGL_TEST_GROUP_ID=0 "
    client_cmd = client_envs + " python3 " + test_bin
    client_hosts = [ip for ip in ips for _ in range(num_clients)]
    for ip in client_hosts:
        procs.append(utils.execute_remote(client_cmd, ip))

    # wait for each process in launch order and require a clean exit
    for proc in procs:
        proc.join()
        assert proc.exitcode == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment