Unverified Commit 0c313e51 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[rpc] Move test_rpc.py to distributed and use dynamic port binding (#1623)

* update

* update

* update

* update

* update

* update

* update

* update
parent 0c408603
......@@ -4,6 +4,7 @@ import dgl
import sys
import numpy as np
import time
import socket
from scipy import sparse as spsp
from numpy.testing import assert_array_equal
from multiprocessing import Process, Manager, Condition, Value
......@@ -16,6 +17,35 @@ import backend as F
import unittest
import pickle
if os.name != 'nt':
import fcntl
import struct
def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()
return ip_addr + ' ' + str(port)
def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
ig = create_graph_index(arr, readonly=True)
......@@ -95,7 +125,7 @@ def test_server_client():
# Partition the graph
num_parts = 1
graph_name = 'test'
graph_name = 'dist_graph_test'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp')
......@@ -126,14 +156,14 @@ def test_split():
g = create_random_graph(10000)
num_parts = 4
num_hops = 2
partition_graph(g, 'test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
partition_graph(g, 'dist_graph_test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0]
for i in range(num_parts):
part_g, node_feats, edge_feats, meta = load_partition('/tmp/test.json', i)
part_g, node_feats, edge_feats, meta = load_partition('/tmp/dist_graph_test.json', i)
num_nodes, num_edges, node_map, edge_map, num_partitions = meta
gpb = GraphPartitionBook(part_id=i,
num_parts=num_partitions,
......@@ -160,7 +190,8 @@ def test_split():
def prepare_dist():
ip_config = open("kv_ip_config.txt", "w")
ip_config.write('127.0.0.1 2500 1\n')
ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close()
if __name__ == '__main__':
......
......@@ -22,10 +22,10 @@ def test_graph_partition_book():
num_parts = 4
num_hops = 2
partition_graph(g, 'test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
partition_graph(g, 'gpb_test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
for i in range(num_parts):
part_g, node_feats, edge_feats, meta = load_partition('/tmp/test.json', i)
part_g, node_feats, edge_feats, meta = load_partition('/tmp/gpb_test.json', i)
num_nodes, num_edges, node_map, edge_map, num_partitions = meta
gpb = GraphPartitionBook(part_id=i,
num_parts=num_partitions,
......
import os
import time
import socket
import dgl
import backend as F
......@@ -7,11 +8,40 @@ import unittest, pytest
from numpy.testing import assert_array_equal
if os.name != 'nt':
import fcntl
import struct
INTEGER = 2
STR = 'hello world!'
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((10, 10), F.int64, F.cpu())
def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()
return ip_addr + ' ' + str(port)
def foo(x, y):
assert x == 123
assert y == "abc"
......@@ -158,7 +188,8 @@ def test_rpc_msg():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
ip_config = open("rpc_ip_config.txt", "w")
ip_config.write('127.0.0.1 30050 1\n')
ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close()
pid = os.fork()
if pid == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment