Unverified Commit 0c313e51 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[rpc] Move test_rpc.py to distributed and use dynamic port binding (#1623)

* update

* update

* update

* update

* update

* update

* update

* update
parent 0c408603
...@@ -4,6 +4,7 @@ import dgl ...@@ -4,6 +4,7 @@ import dgl
import sys import sys
import numpy as np import numpy as np
import time import time
import socket
from scipy import sparse as spsp from scipy import sparse as spsp
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from multiprocessing import Process, Manager, Condition, Value from multiprocessing import Process, Manager, Condition, Value
...@@ -16,6 +17,35 @@ import backend as F ...@@ -16,6 +17,35 @@ import backend as F
import unittest import unittest
import pickle import pickle
if os.name != 'nt':
import fcntl
import struct
def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()
return ip_addr + ' ' + str(port)
def create_random_graph(n): def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64) arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
ig = create_graph_index(arr, readonly=True) ig = create_graph_index(arr, readonly=True)
...@@ -95,7 +125,7 @@ def test_server_client(): ...@@ -95,7 +125,7 @@ def test_server_client():
# Partition the graph # Partition the graph
num_parts = 1 num_parts = 1
graph_name = 'test' graph_name = 'dist_graph_test'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1) g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1) g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp') partition_graph(g, graph_name, num_parts, '/tmp')
...@@ -126,14 +156,14 @@ def test_split(): ...@@ -126,14 +156,14 @@ def test_split():
g = create_random_graph(10000) g = create_random_graph(10000)
num_parts = 4 num_parts = 4
num_hops = 2 num_hops = 2
partition_graph(g, 'test', num_parts, '/tmp', num_hops=num_hops, part_method='metis') partition_graph(g, 'dist_graph_test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30 node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30 edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
selected_nodes = np.nonzero(node_mask)[0] selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0] selected_edges = np.nonzero(edge_mask)[0]
for i in range(num_parts): for i in range(num_parts):
part_g, node_feats, edge_feats, meta = load_partition('/tmp/test.json', i) part_g, node_feats, edge_feats, meta = load_partition('/tmp/dist_graph_test.json', i)
num_nodes, num_edges, node_map, edge_map, num_partitions = meta num_nodes, num_edges, node_map, edge_map, num_partitions = meta
gpb = GraphPartitionBook(part_id=i, gpb = GraphPartitionBook(part_id=i,
num_parts=num_partitions, num_parts=num_partitions,
...@@ -160,7 +190,8 @@ def test_split(): ...@@ -160,7 +190,8 @@ def test_split():
def prepare_dist(): def prepare_dist():
ip_config = open("kv_ip_config.txt", "w") ip_config = open("kv_ip_config.txt", "w")
ip_config.write('127.0.0.1 2500 1\n') ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close() ip_config.close()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,10 +22,10 @@ def test_graph_partition_book(): ...@@ -22,10 +22,10 @@ def test_graph_partition_book():
num_parts = 4 num_parts = 4
num_hops = 2 num_hops = 2
partition_graph(g, 'test', num_parts, '/tmp', num_hops=num_hops, part_method='metis') partition_graph(g, 'gpb_test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
for i in range(num_parts): for i in range(num_parts):
part_g, node_feats, edge_feats, meta = load_partition('/tmp/test.json', i) part_g, node_feats, edge_feats, meta = load_partition('/tmp/gpb_test.json', i)
num_nodes, num_edges, node_map, edge_map, num_partitions = meta num_nodes, num_edges, node_map, edge_map, num_partitions = meta
gpb = GraphPartitionBook(part_id=i, gpb = GraphPartitionBook(part_id=i,
num_parts=num_partitions, num_parts=num_partitions,
......
import os import os
import time import time
import socket
import dgl import dgl
import backend as F import backend as F
...@@ -7,11 +8,40 @@ import unittest, pytest ...@@ -7,11 +8,40 @@ import unittest, pytest
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
if os.name != 'nt':
import fcntl
import struct
INTEGER = 2 INTEGER = 2
STR = 'hello world!' STR = 'hello world!'
HELLO_SERVICE_ID = 901231 HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((10, 10), F.int64, F.cpu()) TENSOR = F.zeros((10, 10), F.int64, F.cpu())
def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()
return ip_addr + ' ' + str(port)
def foo(x, y): def foo(x, y):
assert x == 123 assert x == 123
assert y == "abc" assert y == "abc"
...@@ -158,7 +188,8 @@ def test_rpc_msg(): ...@@ -158,7 +188,8 @@ def test_rpc_msg():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc(): def test_rpc():
ip_config = open("rpc_ip_config.txt", "w") ip_config = open("rpc_ip_config.txt", "w")
ip_config.write('127.0.0.1 30050 1\n') ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close() ip_config.close()
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment