test_rpc.py 1.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import unittest
import pytest
import multiprocessing as mp
import utils

# Shell-style "NAME=value " prefix shared by every command this module launches.
# The DGL-related variables are copied from the current environment; a variable
# that is not set is rendered as the literal text "None". The trailing space is
# intentional so that more assignments can be appended directly.
dgl_envs = (
    "PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 "
    f"DGLBACKEND={os.environ.get('DGLBACKEND')} "
    f"DGL_LIBRARY_PATH={os.environ.get('DGL_LIBRARY_PATH')} "
    f"PYTHONPATH={os.environ.get('PYTHONPATH')} "
)


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_rpc(net_type):
    """Run ``rpc_basic.py`` as server and client on every configured machine.

    The machine list comes from ``utils.get_ips`` applied to the file named by
    ``DIST_DGL_TEST_IP_CONFIG`` (default ``ip_config.txt``). The script is
    looked up in ``DIST_DGL_TEST_PY_BIN_DIR`` (default ``.``). For each machine
    one server and one client command are handed to ``utils.execute_remote``;
    all servers are started before any client. The test passes when every
    returned process handle finishes with exit code 0.

    Parameters
    ----------
    net_type : str
        Value exported to the launched script as ``DIST_DGL_TEST_NET_TYPE``.
    """
    ip_config = os.environ.get('DIST_DGL_TEST_IP_CONFIG', 'ip_config.txt')
    clients_per_machine = 1
    servers_per_machine = 1
    ips = utils.get_ips(ip_config)
    # Every server is told the cluster-wide client count, which does not
    # change between iterations, so compute it once.
    total_clients = clients_per_machine * len(ips)

    bin_dir = os.environ.get('DIST_DGL_TEST_PY_BIN_DIR', '.')
    test_bin = os.path.join(bin_dir, 'rpc_basic.py')

    # Environment assignments common to both roles.
    base_envs = dgl_envs + \
        f" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={servers_per_machine} DIST_DGL_TEST_NET_TYPE={net_type} "

    # Servers: ids are assigned consecutively, machine by machine.
    server_hosts = [
        host for host in ips for _ in range(servers_per_machine)
    ]
    launched = []
    for server_id, host in enumerate(server_hosts):
        server_envs = base_envs + \
            f" DIST_DGL_TEST_ROLE=server DIST_DGL_TEST_SERVER_ID={server_id} DIST_DGL_TEST_NUM_CLIENTS={total_clients} "
        server_cmd = server_envs + " python3 " + test_bin
        launched.append(utils.execute_remote(server_cmd, host))

    # Clients: the command line is the same for every client on every machine.
    client_envs = base_envs + " DIST_DGL_TEST_ROLE=client DIST_DGL_TEST_GROUP_ID=0 "
    client_cmd = client_envs + " python3 " + test_bin
    for host in ips:
        for _ in range(clients_per_machine):
            launched.append(utils.execute_remote(client_cmd, host))

    # Wait in launch order and stop at the first process that did not exit
    # cleanly.
    for handle in launched:
        handle.join()
        assert handle.exitcode == 0