# test_rpc.py
1
import multiprocessing as mp
2
3
import os
import unittest
4

5
6
7
8
9
10
import pytest
import utils

# Shell environment prefix prepended to every remote command launched by this
# test module. DGLBACKEND, DGL_LIBRARY_PATH and PYTHONPATH are forwarded from
# the local environment only when they are actually set: formatting a missing
# os.environ.get() result into the string would hand the remote process the
# literal text "None" as the variable's value. When all three are set the
# resulting string is identical to the previous f-string form.
# The trailing space is kept so further assignments can be appended directly.
dgl_envs = " ".join(
    ["PYTHONUNBUFFERED=1", "DMLC_LOG_DEBUG=1"]
    + [
        f"{name}={os.environ[name]}"
        for name in ("DGLBACKEND", "DGL_LIBRARY_PATH", "PYTHONPATH")
        if name in os.environ
    ]
    + [""]  # empty last item makes join() emit the trailing space
)


11
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_rpc():
    """Run ``rpc_basic.py`` as servers and clients on every configured host.

    The host list comes from ``utils.get_ips`` applied to the file named by
    ``DIST_DGL_TEST_IP_CONFIG`` (default ``ip_config.txt``). The script is
    looked up in ``DIST_DGL_TEST_PY_BIN_DIR`` (default ``.``). One server and
    one client are started per host through ``utils.execute_remote``; server
    ids are assigned consecutively across hosts, starting at 0. Each returned
    handle is then joined in launch order (servers first) and its
    ``exitcode`` must be 0.

    The test is skipped when ``os.name == "nt"``.
    """
    ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
    clients_per_host = 1
    servers_per_host = 1
    hosts = utils.get_ips(ip_config)
    total_clients = clients_per_host * len(hosts)

    bin_dir = os.environ.get("DIST_DGL_TEST_PY_BIN_DIR", ".")
    # Command tail shared by every launch: the interpreter plus the script.
    run_script = " python3 " + os.path.join(bin_dir, "rpc_basic.py")

    # Environment assignments common to both roles.
    shared_envs = (
        dgl_envs
        + f" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={servers_per_host} "
    )

    # Servers are launched first; enumerate() hands out the running server id
    # across all (host, slot) pairs.
    server_hosts = (
        host for host in hosts for _ in range(servers_per_host)
    )
    handles = []
    for sid, host in enumerate(server_hosts):
        server_cmd = (
            shared_envs
            + f" DIST_DGL_TEST_ROLE=server DIST_DGL_TEST_SERVER_ID={sid} DIST_DGL_TEST_NUM_CLIENTS={total_clients} "
            + run_script
        )
        handles.append(utils.execute_remote(server_cmd, host))

    # Every client uses the same command line, so build it once.
    client_cmd = (
        shared_envs
        + " DIST_DGL_TEST_ROLE=client DIST_DGL_TEST_GROUP_ID=0 "
        + run_script
    )
    for host in hosts:
        for _ in range(clients_per_host):
            handles.append(utils.execute_remote(client_cmd, host))

    # Wait for each launched process in turn and require a clean exit.
    for handle in handles:
        handle.join()
        assert handle.exitcode == 0