utils.py 1.6 KB
Newer Older
1
import os
2
import random
3
4
import socket

5
import numpy as np
6
7
import scipy.sparse as spsp

8
import dgl
Jinjing Zhou's avatar
Jinjing Zhou committed
9
10


11
12
13
def _get_local_ip():
    """Return the IPv4 address of this host's outbound interface.

    Falls back to ``"127.0.0.1"`` when the OS cannot select a route (for
    example on a host without network access).
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # connect() on a UDP socket only selects a route / local address and
        # sends nothing, so the target doesn't even have to be reachable.
        sock.connect(("10.255.255.255", 1))
        return sock.getsockname()[0]
    except OSError:
        # e.g. "Network is unreachable" when there is no route at all.
        return "127.0.0.1"
    finally:
        sock.close()


def _is_port_in_use(ip, port):
    """Return True if a TCP connection to ``ip:port`` succeeds."""
    # Use a fresh socket for every probe: a socket that already connected
    # (or failed to) cannot be reliably reused for another connect().
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        return probe.connect_ex((ip, port)) == 0


def _find_free_port_run(ip, count, start, stop=65536):
    """Return up to ``count`` consecutive ports in ``[start, stop)`` free on ``ip``.

    A port counts as free when a TCP connect to it fails. If no run of
    ``count`` consecutive free ports exists, the (shorter) run in progress
    when the range is exhausted is returned.
    """
    ports = []
    for port in range(start, stop):
        if len(ports) >= count:
            break
        if _is_port_in_use(ip, port):
            # Restart the run so the returned ports stay contiguous.
            ports = []
        else:
            ports.append(port)
    return ports


def generate_ip_config(file_name, num_machines, num_servers):
    """Get local IP and available ports, writes to file.

    Detects this host's IP address, then scans upward from a random start
    port in [10000, 30000] for ``num_machines * num_servers`` consecutive TCP
    ports that nothing is listening on. It writes one ``"<ip> <port>"`` line
    per machine, where machine ``i`` gets ``ports[i * num_servers]``. All
    lines use the same local IP.

    Args:
        file_name: Path of the config file to create or overwrite.
        num_machines: Number of lines to write.
        num_servers: Number of consecutive ports reserved per machine.

    Raises:
        RuntimeError: If not enough consecutive free ports were found.
    """
    ip = _get_local_ip()
    needed = num_machines * num_servers
    ports = _find_free_port_run(ip, needed, random.randint(10000, 30000))
    if len(ports) < needed:
        raise RuntimeError(
            "Failed to get available IP/PORT with required numbers."
        )
    # NOTE(review): only the first port of each machine's run is written;
    # presumably the consumer derives the other num_servers - 1 ports from
    # it -- confirm against the launcher.
    with open(file_name, "w") as f:
        for i in range(num_machines):
            f.write("{} {}\n".format(ip, ports[i * num_servers]))
44
45
46


def reset_envs():
    """Clear the DGL environment variables that tests commonly set.

    Each variable listed below is removed from ``os.environ`` if present;
    variables that are not set are silently skipped.
    """
    dgl_test_vars = (
        "DGL_ROLE",
        "DGL_NUM_SAMPLER",
        "DGL_NUM_SERVER",
        "DGL_DIST_MODE",
        "DGL_NUM_CLIENT",
        "DGL_DIST_MAX_TRY_TIMES",
    )
    for name in dgl_test_vars:
        # pop() with a default is a no-op when the variable is unset.
        os.environ.pop(name, None)
58
59
60
61


def create_random_graph(n):
    """Build a random graph with ``dgl.rand_graph``.

    The graph has ``n`` nodes and ``int(n * n * 0.001)`` edges, i.e. an edge
    count of 0.1% of ``n * n``.
    """
    num_edges = int(n * n * 0.001)
    return dgl.rand_graph(n, num_edges)