# test_rpc.py
import multiprocessing as mp
2
3
import os
import unittest
4

5
6
7
8
9
10
import pytest
import utils

# Space-separated VAR=value prefix that test_rpc prepends to the commands it
# launches. The string always ends with a single trailing space.
#
# Only variables that are actually set in the local environment are forwarded:
# interpolating os.environ.get() for a missing variable would emit the literal
# text "None" (e.g. "DGLBACKEND=None") into the command line.
dgl_envs = "PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 " + "".join(
    f"{name}={os.environ[name]} "
    for name in ("DGLBACKEND", "DGL_LIBRARY_PATH", "PYTHONPATH")
    if name in os.environ
)


11
12
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_rpc(net_type):
    """Run ``rpc_basic.py`` as servers and clients on every configured host.

    For each entry returned by ``utils.get_ips`` the test launches
    ``num_servers`` server commands followed by ``num_clients`` client
    commands through ``utils.execute_remote``, then joins every launched
    process and requires an exit code of 0.

    Environment variables read:
        DIST_DGL_TEST_IP_CONFIG: path handed to ``utils.get_ips``
            (default ``"ip_config.txt"``).
        DIST_DGL_TEST_PY_BIN_DIR: directory containing ``rpc_basic.py``
            (default ``"."``).

    Args:
        net_type: value exported to the launched commands as
            ``DIST_DGL_TEST_NET_TYPE``; parametrized over ``"socket"`` and
            ``"tensorpipe"``.
    """
    ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
    num_clients = 1
    num_servers = 1
    # presumably one address per machine listed in ip_config -- confirm in utils
    ips = utils.get_ips(ip_config)
    num_machines = len(ips)
    test_bin = os.path.join(
        os.environ.get("DIST_DGL_TEST_PY_BIN_DIR", "."), "rpc_basic.py"
    )
    # Settings shared by the server and client command lines.
    base_envs = (
        dgl_envs
        + f" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={num_servers} DIST_DGL_TEST_NET_TYPE={net_type} "
    )

    # (label, process) pairs; the label only feeds the failure message below.
    procs = []

    # start server processes
    server_id = 0
    for ip in ips:
        for _ in range(num_servers):
            # Each server is told the total client count across all machines.
            server_envs = (
                base_envs
                + f" DIST_DGL_TEST_ROLE=server DIST_DGL_TEST_SERVER_ID={server_id} DIST_DGL_TEST_NUM_CLIENTS={num_clients * num_machines} "
            )
            proc = utils.execute_remote(
                server_envs + " python3 " + test_bin, ip
            )
            procs.append((f"server {server_id} on {ip}", proc))
            server_id += 1

    # start client processes
    client_envs = (
        base_envs + " DIST_DGL_TEST_ROLE=client DIST_DGL_TEST_GROUP_ID=0 "
    )
    for ip in ips:
        for _ in range(num_clients):
            proc = utils.execute_remote(
                client_envs + " python3 " + test_bin, ip
            )
            procs.append((f"client on {ip}", proc))

    # Processes are checked in launch order; the first non-zero exit code
    # fails the test immediately.
    for label, p in procs:
        p.join()
        assert p.exitcode == 0, f"{label} exited with code {p.exitcode}"