import multiprocessing as mp
import os
import unittest

import pytest

import utils

# Environment-variable prefix placed in front of every remote command.
# A variable that is unset locally is rendered as the literal text "None",
# because os.environ.get() returns None and the f-string formats it as-is.
dgl_envs = (
    "PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 "
    f"DGLBACKEND={os.environ.get('DGLBACKEND')} "
    f"DGL_LIBRARY_PATH={os.environ.get('DGL_LIBRARY_PATH')} "
    f"PYTHONPATH={os.environ.get('PYTHONPATH')} "
)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_rpc(net_type):
    """Launch rpc_basic.py as servers and clients and check the exit codes.

    The machine list comes from the file named by DIST_DGL_TEST_IP_CONFIG
    (default "ip_config.txt") and the script directory from
    DIST_DGL_TEST_PY_BIN_DIR (default "."). One server and one client are
    started per machine; all servers are launched before any client.

    Args:
        net_type: Value exported to the launched script as
            DIST_DGL_TEST_NET_TYPE ("socket" or "tensorpipe").
    """
    num_servers = 1
    num_clients = 1

    ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
    machine_ips = utils.get_ips(ip_config)
    num_machines = len(machine_ips)

    script_dir = os.environ.get("DIST_DGL_TEST_PY_BIN_DIR", ".")
    test_bin = os.path.join(script_dir, "rpc_basic.py")
    run_script = " python3 " + test_bin

    # Settings shared by every launched process, whatever its role.
    base_envs = dgl_envs + (
        " DGL_DIST_MODE=distributed "
        f"DIST_DGL_TEST_IP_CONFIG={ip_config} "
        f"DIST_DGL_TEST_NUM_SERVERS={num_servers} "
        f"DIST_DGL_TEST_NET_TYPE={net_type} "
    )

    # NOTE(review): utils.execute_remote is assumed to return a process-like
    # handle; only its join() method and exitcode attribute are used below.
    launched = []

    # Servers: ids are assigned consecutively, machine by machine.
    total_clients = num_clients * num_machines
    for machine_idx, ip in enumerate(machine_ips):
        for local_idx in range(num_servers):
            server_id = machine_idx * num_servers + local_idx
            server_envs = base_envs + (
                " DIST_DGL_TEST_ROLE=server "
                f"DIST_DGL_TEST_SERVER_ID={server_id} "
                f"DIST_DGL_TEST_NUM_CLIENTS={total_clients} "
            )
            launched.append(utils.execute_remote(server_envs + run_script, ip))

    # Clients: the command is identical for all of them, so build it once.
    client_envs = (
        base_envs + " DIST_DGL_TEST_ROLE=client DIST_DGL_TEST_GROUP_ID=0 "
    )
    client_cmd = client_envs + run_script
    for ip in machine_ips:
        launched.extend(
            utils.execute_remote(client_cmd, ip) for _ in range(num_clients)
        )

    # Wait in launch order; the first non-zero exit code fails the test.
    for handle in launched:
        handle.join()
        assert handle.exitcode == 0