Unverified Commit d340ea3a authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[Distributed] Remove server_count from ip_config.txt (#1985)



* remove server_count from ip_config.txt

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update dist_context.py

* fix lint.

* make it work for multiple spaces.

* update ip_config.txt.

* fix examples.

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-19-1.us-west-2.compute.internal>
parent d090ae86
...@@ -26,9 +26,11 @@ The data is copied to `~/graphsage/ogb-product` on each of the remote machines. ...@@ -26,9 +26,11 @@ The data is copied to `~/graphsage/ogb-product` on each of the remote machines.
specifies the location of the partitioned data in the local machine (a user only needs to specify specifies the location of the partitioned data in the local machine (a user only needs to specify
the location of the partition configuration file). the location of the partition configuration file).
```bash ```bash
python3 ~/dgl/tools/copy_partitions.py --ip_config ip_config.txt \ python3 ~/dgl/tools/copy_partitions.py \
--workspace ~/graphsage --rel_data_path ogb-product \ --ip_config ip_config.txt \
--part_config data/ogb-product.json --workspace ~/graphsage \
--rel_data_path ogb-product \
--part_config data/ogb-product.json
``` ```
**Note**: users need to make sure that the master node has right permission to ssh to all the other nodes. **Note**: users need to make sure that the master node has right permission to ssh to all the other nodes.
...@@ -45,20 +47,22 @@ python3 ~/dgl/tools/launch.py \ ...@@ -45,20 +47,22 @@ python3 ~/dgl/tools/launch.py \
--workspace ~/graphsage/ \ --workspace ~/graphsage/ \
--num_trainers 1 \ --num_trainers 1 \
--num_samplers 4 \ --num_samplers 4 \
--num_servers 1 \
--part_config ogb-product/ogb-product.json \ --part_config ogb-product/ogb-product.json \
--ip_config ip_config.txt \ --ip_config ip_config.txt \
"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000 --num-workers 4" "python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-servers 1 --num-epochs 30 --batch-size 1000 --num-workers 4"
``` ```
To run unsupervised training: To run unsupervised training:
```bash ```bash
python3 ~/dgl/tools/launch.py \ python3 ~/dgl/tools/launch.py \
--workspace ~/dgl/examples/pytorch/graphsage/experimental \ --workspace ~/graphsage/ \
--num_trainers 1 \ --num_trainers 1 \
--num_servers 1 \
--part_config ogb-product/ogb-product.json \ --part_config ogb-product/ogb-product.json \
--ip_config ip_config.txt \ --ip_config ip_config.txt \
"python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000" "python3 ~/dgl/examples/pytorch/graphsage/experimental/train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-servers 1 --num-epochs 3 --batch-size 1000"
``` ```
## Distributed code runs in the standalone mode ## Distributed code runs in the standalone mode
......
172.31.19.1 5555 2 172.31.19.1 5555
172.31.23.205 5555 2 172.31.23.205 5555
172.31.29.175 5555 2 172.31.29.175 5555
172.31.16.98 5555 2 172.31.16.98 5555
\ No newline at end of file
...@@ -263,7 +263,7 @@ def main(args): ...@@ -263,7 +263,7 @@ def main(args):
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers) dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config) g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank()) print('rank:', g.rank())
...@@ -295,6 +295,7 @@ if __name__ == '__main__': ...@@ -295,6 +295,7 @@ if __name__ == '__main__':
parser.add_argument('--ip_config', type=str, help='The file for IP configuration') parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file') parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num-client', type=int, help='The number of clients') parser.add_argument('--num-client', type=int, help='The number of clients')
parser.add_argument('--num-servers', type=int, default=1, help='The number of servers')
parser.add_argument('--n-classes', type=int, help='the number of classes') parser.add_argument('--n-classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0, parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training") help="GPU device ID. Use -1 for CPU training")
......
...@@ -22,7 +22,78 @@ import torch.optim as optim ...@@ -22,7 +22,78 @@ import torch.optim as optim
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
#from pyinstrument import Profiler #from pyinstrument import Profiler
from train_sampling import SAGE
class SAGE(nn.Module):
    """Multi-layer GraphSAGE encoder built from DGL ``SAGEConv`` ('mean') layers.

    Parameters
    ----------
    in_feats : int
        Size of each input node feature.
    n_hidden : int
        Hidden layer width.
    n_classes : int
        Size of the output (number of classes).
    n_layers : int
        Total number of SAGEConv layers. NOTE(review): with ``n_layers == 1``
        this still builds two layers (in->hidden, hidden->classes); callers
        appear to always pass >= 2 — confirm before relying on n_layers=1.
    activation : callable
        Non-linearity applied after every layer except the last.
    dropout : float
        Dropout probability applied after every layer except the last.
    """
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        # First layer maps input features to the hidden width.
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        # Middle layers keep the hidden width.
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        # Final layer projects to the output classes.
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """Run the sampled-minibatch forward pass.

        ``blocks`` is one message-flow block per layer (outermost first);
        ``x`` holds the input features of the block-0 source nodes.
        Activation and dropout are skipped on the last layer.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of nodes per inference minibatch.
        device : device on which layer computation runs.
        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        for l, layer in enumerate(self.layers):
            # Output buffer for this layer over ALL nodes (kept on CPU).
            y = th.zeros(g.number_of_nodes(),
                         self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

            # [None] => take all neighbors (full-neighbor, single-layer blocks).
            sampler = dgl.sampling.MultiLayerNeighborSampler([None])
            dataloader = dgl.sampling.NodeDataLoader(
                g,
                th.arange(g.number_of_nodes()),
                sampler,
                # Fix: honor the batch_size parameter instead of the
                # module-level `args.batch_size` global, which silently
                # ignored the caller's argument.
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                # NOTE(review): still reads the module-level `args` global;
                # no parameter exists for it — consider threading it through.
                num_workers=args.num_workers)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0]
                block = block.int().to(device)
                h = x[input_nodes].to(device)
                h = layer(block, h)
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    # Dropout is a no-op once the caller puts the module in
                    # eval() mode; kept for parity with the training path.
                    h = self.dropout(h)
                y[output_nodes] = h.cpu()
            # Output of this layer becomes input to the next.
            x = y
        return y
class NegativeSampler(object): class NegativeSampler(object):
def __init__(self, g, neg_nseeds): def __init__(self, g, neg_nseeds):
...@@ -349,7 +420,7 @@ def run(args, device, data): ...@@ -349,7 +420,7 @@ def run(args, device, data):
def main(args): def main(args):
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers) dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config) g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank()) print('rank:', g.rank())
print('number of edges', g.number_of_edges()) print('number of edges', g.number_of_edges())
...@@ -381,6 +452,7 @@ if __name__ == '__main__': ...@@ -381,6 +452,7 @@ if __name__ == '__main__':
parser.add_argument('--id', type=int, help='the partition id') parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration') parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file') parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num-servers', type=int, help='Server count on each machine.')
parser.add_argument('--n-classes', type=int, help='the number of classes') parser.add_argument('--n-classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0, parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training") help="GPU device ID. Use -1 for CPU training")
......
...@@ -21,12 +21,15 @@ if os.environ.get('DGL_ROLE', 'client') == 'server': ...@@ -21,12 +21,15 @@ if os.environ.get('DGL_ROLE', 'client') == 'server':
'Please define DGL_SERVER_ID to run DistGraph server' 'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \ assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server' 'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_SERVER') is not None, \
'Please define DGL_NUM_SERVER to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \ assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server' 'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \ assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server' 'Please define DGL_CONF_PATH to run DistGraph server'
SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')), SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'), os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')), int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH')) os.environ.get('DGL_CONF_PATH'))
SERV.start() SERV.start()
......
...@@ -22,34 +22,38 @@ def set_initialized(value=True): ...@@ -22,34 +22,38 @@ def set_initialized(value=True):
global INITIALIZED global INITIALIZED
INITIALIZED = value INITIALIZED = value
def get_sampler_pool(): def get_sampler_pool():
"""Return the sampler pool and num_workers""" """Return the sampler pool and num_workers"""
return SAMPLER_POOL, NUM_SAMPLER_WORKERS return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads):
def _init_rpc(ip_config, max_queue_size, net_type, role, num_threads):
''' This init function is called in the worker processes. ''' This init function is called in the worker processes.
''' '''
try: try:
utils.set_num_threads(num_threads) utils.set_num_threads(num_threads)
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone': if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
connect_to_server(ip_config, max_queue_size, net_type) connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role(role) init_role(role)
init_kvstore(ip_config, role) init_kvstore(ip_config, num_servers, role)
except Exception as e: except Exception as e:
print(e, flush=True) print(e, flush=True)
traceback.print_exc() traceback.print_exc()
raise e raise e
def initialize(ip_config, num_servers=1, num_workers=0,
def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type='socket', max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
num_worker_threads=1): num_worker_threads=1):
"""Init rpc service """Init rpc service
Parameters
----------
ip_config: str ip_config: str
File path of ip_config file File path of ip_config file
num_servers : int
The number of server processes on each machine
num_workers: int num_workers: int
Number of worker process to be created Number of worker process on each machine. The worker processes are used
for distributed sampling.
max_queue_size : int max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default). Maximal size (bytes) of client queue buffer (~20 GB on default).
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
...@@ -66,16 +70,15 @@ def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type ...@@ -66,16 +70,15 @@ def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type
is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone' is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone'
if num_workers > 0 and not is_standalone: if num_workers > 0 and not is_standalone:
SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc, SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
initargs=(ip_config, max_queue_size, initargs=(ip_config, num_servers, max_queue_size,
net_type, 'sampler', num_worker_threads)) net_type, 'sampler', num_worker_threads))
else: else:
SAMPLER_POOL = None SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers NUM_SAMPLER_WORKERS = num_workers
if not is_standalone: if not is_standalone:
connect_to_server(ip_config, max_queue_size, net_type) connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role('default') init_role('default')
init_kvstore(ip_config, 'default') init_kvstore(ip_config, num_servers, 'default')
def finalize_client(): def finalize_client():
"""Release resources of this client.""" """Release resources of this client."""
...@@ -85,12 +88,10 @@ def finalize_client(): ...@@ -85,12 +88,10 @@ def finalize_client():
global INITIALIZED global INITIALIZED
INITIALIZED = False INITIALIZED = False
def _exit(): def _exit():
exit_client() exit_client()
time.sleep(1) time.sleep(1)
def finalize_worker(): def finalize_worker():
"""Finalize workers """Finalize workers
Python's multiprocessing pool will not call atexit function when close Python's multiprocessing pool will not call atexit function when close
...@@ -113,7 +114,6 @@ def is_initialized(): ...@@ -113,7 +114,6 @@ def is_initialized():
""" """
return INITIALIZED return INITIALIZED
def exit_client(): def exit_client():
"""Register exit callback. """Register exit callback.
""" """
......
...@@ -241,6 +241,8 @@ class DistGraphServer(KVServer): ...@@ -241,6 +241,8 @@ class DistGraphServer(KVServer):
The server ID (start from 0). The server ID (start from 0).
ip_config : str ip_config : str
Path of IP configuration file. Path of IP configuration file.
num_servers : int
Server count on each machine.
num_clients : int num_clients : int
Total number of client nodes. Total number of client nodes.
part_config : string part_config : string
...@@ -248,10 +250,14 @@ class DistGraphServer(KVServer): ...@@ -248,10 +250,14 @@ class DistGraphServer(KVServer):
disable_shared_mem : bool disable_shared_mem : bool
Disable shared memory. Disable shared memory.
''' '''
def __init__(self, server_id, ip_config, num_clients, part_config, disable_shared_mem=False): def __init__(self, server_id, ip_config, num_servers,
super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config, num_clients, part_config, disable_shared_mem=False):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
num_clients=num_clients) num_clients=num_clients)
self.ip_config = ip_config self.ip_config = ip_config
self.num_servers = num_servers
# Load graph partition data. # Load graph partition data.
if self.is_backup_server(): if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards. # The backup server doesn't load the graph partition. It'll initialized afterwards.
...@@ -286,7 +292,9 @@ class DistGraphServer(KVServer): ...@@ -286,7 +292,9 @@ class DistGraphServer(KVServer):
# start server # start server
server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb) server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
print('start graph service on server {} for part {}'.format(self.server_id, self.part_id)) print('start graph service on server {} for part {}'.format(self.server_id, self.part_id))
start_server(server_id=self.server_id, ip_config=self.ip_config, start_server(server_id=self.server_id,
ip_config=self.ip_config,
num_servers=self.num_servers,
num_clients=self.num_clients, server_state=server_state) num_clients=self.num_clients, server_state=server_state)
class DistGraph: class DistGraph:
......
...@@ -591,11 +591,14 @@ class KVServer(object): ...@@ -591,11 +591,14 @@ class KVServer(object):
ID of current server (starts from 0). ID of current server (starts from 0).
ip_config : str ip_config : str
Path of IP configuration file. Path of IP configuration file.
num_servers : int
Server count on each machine.
num_clients : int num_clients : int
Total number of KVClients that will be connected to the KVServer. Total number of KVClients that will be connected to the KVServer.
""" """
def __init__(self, server_id, ip_config, num_clients): def __init__(self, server_id, ip_config, num_servers, num_clients):
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config
assert num_clients >= 0, 'num_clients (%d) cannot be a negative number.' % num_clients assert num_clients >= 0, 'num_clients (%d) cannot be a negative number.' % num_clients
# Register services on server # Register services on server
...@@ -636,7 +639,7 @@ class KVServer(object): ...@@ -636,7 +639,7 @@ class KVServer(object):
self._part_policy = {} self._part_policy = {}
# Basic information # Basic information
self._server_id = server_id self._server_id = server_id
self._server_namebook = rpc.read_ip_config(ip_config) self._server_namebook = rpc.read_ip_config(ip_config, num_servers)
assert server_id in self._server_namebook, \ assert server_id in self._server_namebook, \
'Trying to start server {}, but there are {} servers in the config file'.format( 'Trying to start server {}, but there are {} servers in the config file'.format(
server_id, len(self._server_namebook)) server_id, len(self._server_namebook))
...@@ -773,13 +776,16 @@ class KVClient(object): ...@@ -773,13 +776,16 @@ class KVClient(object):
---------- ----------
ip_config : str ip_config : str
Path of IP configuration file. Path of IP configuration file.
num_servers : int
Server count on each machine.
role : str role : str
We can set different role for kvstore. We can set different role for kvstore.
""" """
def __init__(self, ip_config, role='default'): def __init__(self, ip_config, num_servers, role='default'):
assert rpc.get_rank() != -1, \ assert rpc.get_rank() != -1, \
'Please invoke rpc.connect_to_server() before creating KVClient.' 'Please invoke rpc.connect_to_server() before creating KVClient.'
assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
# Register services on client # Register services on client
rpc.register_service(KVSTORE_PULL, rpc.register_service(KVSTORE_PULL,
PullRequest, PullRequest,
...@@ -820,7 +826,7 @@ class KVClient(object): ...@@ -820,7 +826,7 @@ class KVClient(object):
# Store all the data name # Store all the data name
self._data_name_list = set() self._data_name_list = set()
# Basic information # Basic information
self._server_namebook = rpc.read_ip_config(ip_config) self._server_namebook = rpc.read_ip_config(ip_config, num_servers)
self._server_count = len(self._server_namebook) self._server_count = len(self._server_namebook)
self._group_count = self._server_namebook[0][3] self._group_count = self._server_namebook[0][3]
self._machine_count = int(self._server_count / self._group_count) self._machine_count = int(self._server_count / self._group_count)
...@@ -1230,14 +1236,14 @@ class KVClient(object): ...@@ -1230,14 +1236,14 @@ class KVClient(object):
KVCLIENT = None KVCLIENT = None
def init_kvstore(ip_config, role): def init_kvstore(ip_config, num_servers, role):
"""initialize KVStore""" """initialize KVStore"""
global KVCLIENT global KVCLIENT
if KVCLIENT is None: if KVCLIENT is None:
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone': if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
KVCLIENT = SA_KVClient() KVCLIENT = SA_KVClient()
else: else:
KVCLIENT = KVClient(ip_config, role) KVCLIENT = KVClient(ip_config, num_servers, role)
def close_kvstore(): def close_kvstore():
"""Close the current KVClient""" """Close the current KVClient"""
......
...@@ -22,21 +22,28 @@ REQUEST_CLASS_TO_SERVICE_ID = {} ...@@ -22,21 +22,28 @@ REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {} RESPONSE_CLASS_TO_SERVICE_ID = {}
SERVICE_ID_TO_PROPERTY = {} SERVICE_ID_TO_PROPERTY = {}
def read_ip_config(filename): DEFUALT_PORT = 30050
def read_ip_config(filename, num_servers):
"""Read network configuration information of server from file. """Read network configuration information of server from file.
The format of configuration file should be: For example, the following TXT shows a 4-machine configuration:
172.31.40.143
172.31.36.140
172.31.47.147
172.31.30.180
[ip] [base_port] [server_count] Users can also set user-specified port for this network configuration. For example:
172.31.40.143 30050 2 172.31.40.143 20090
172.31.36.140 30050 2 172.31.36.140 20090
172.31.47.147 30050 2 172.31.47.147 20090
172.31.30.180 30050 2 172.31.30.180 20090
Note that, DGL supports multiple backup servers that share data with each other Note that, DGL supports multiple backup servers that share data with each other
on the same machine via shared-memory tensor. The server_count should be >= 1. For example, on the same machine via shared-memory tensor. The num_servers should be >= 1. For example,
if we set server_count to 5, it means that we have 1 main server and 4 backup servers on if we set num_servers to 5, it means that we have 1 main server and 4 backup servers on
current machine. current machine.
Parameters Parameters
...@@ -44,12 +51,15 @@ def read_ip_config(filename): ...@@ -44,12 +51,15 @@ def read_ip_config(filename):
filename : str filename : str
Path of IP configuration file. Path of IP configuration file.
num_servers : int
Server count on each machine.
Returns Returns
------- -------
dict dict
server namebook. server namebook.
The key is server_id (int) The key is server_id (int)
The value is [machine_id, ip, port, group_count] ([int, str, int, int]) The value is [machine_id, ip, port, num_servers] ([int, str, int, int])
e.g., e.g.,
...@@ -63,23 +73,29 @@ def read_ip_config(filename): ...@@ -63,23 +73,29 @@ def read_ip_config(filename):
7:[3, '172.31.30.180', 30051, 2]} 7:[3, '172.31.30.180', 30051, 2]}
""" """
assert len(filename) > 0, 'filename cannot be empty.' assert len(filename) > 0, 'filename cannot be empty.'
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
server_namebook = {} server_namebook = {}
try: try:
server_id = 0 server_id = 0
machine_id = 0 machine_id = 0
lines = [line.rstrip('\n') for line in open(filename)] lines = [line.rstrip('\n') for line in open(filename)]
for line in lines: for line in lines:
ip_addr, port, server_count = line.split(' ') result = line.split()
for s_count in range(int(server_count)): if len(result) == 2:
server_namebook[server_id] = \ port = int(result[1])
[int(machine_id), ip_addr, int(port)+s_count, int(server_count)] elif len(result) == 1:
port = DEFUALT_PORT
else:
raise RuntimeError('length of result can only be 1 or 2.')
ip_addr = result[0]
for s_count in range(num_servers):
server_namebook[server_id] = [machine_id, ip_addr, port+s_count, num_servers]
server_id += 1 server_id += 1
machine_id += 1 machine_id += 1
except ValueError: except RuntimeError:
print("Error: data format on each line should be: [ip] [base_port] [server_count]") print("Error: data format on each line should be: [ip] [port]")
return server_namebook return server_namebook
def reset(): def reset():
"""Reset the rpc context """Reset the rpc context
""" """
......
...@@ -97,13 +97,15 @@ def get_local_usable_addr(): ...@@ -97,13 +97,15 @@ def get_local_usable_addr():
return ip_addr + ':' + str(port) return ip_addr + ':' + str(port)
def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'): def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Connect this client to server. """Connect this client to server.
Parameters Parameters
---------- ----------
ip_config : str ip_config : str
Path of server IP configuration file. Path of server IP configuration file.
num_servers : int
server count on each machine.
max_queue_size : int max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default). Maximal size (bytes) of client queue buffer (~20 GB on default).
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
...@@ -115,6 +117,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket ...@@ -115,6 +117,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
------ ------
ConnectionError : If anything wrong with the connection. ConnectionError : If anything wrong with the connection.
""" """
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'.' % net_type assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'.' % net_type
# Register some basic service # Register some basic service
...@@ -131,7 +134,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket ...@@ -131,7 +134,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.ClientBarrierRequest, rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse) rpc.ClientBarrierResponse)
rpc.register_ctrl_c() rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config) server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook) num_servers = len(server_namebook)
rpc.set_num_server(num_servers) rpc.set_num_server(num_servers)
# group_count means how many servers # group_count means how many servers
......
...@@ -5,7 +5,7 @@ import time ...@@ -5,7 +5,7 @@ import time
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
def start_server(server_id, ip_config, num_clients, server_state, \ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'): max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Start DGL server, which will be shared with all the rpc services. """Start DGL server, which will be shared with all the rpc services.
...@@ -17,6 +17,8 @@ def start_server(server_id, ip_config, num_clients, server_state, \ ...@@ -17,6 +17,8 @@ def start_server(server_id, ip_config, num_clients, server_state, \
Current server ID (starts from 0). Current server ID (starts from 0).
ip_config : str ip_config : str
Path of IP configuration file. Path of IP configuration file.
num_servers : int
Server count on each machine.
num_clients : int num_clients : int
Total number of clients that will be connected to the server. Total number of clients that will be connected to the server.
Note that, we do not support dynamic connection for now. It means Note that, we do not support dynamic connection for now. It means
...@@ -32,6 +34,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \ ...@@ -32,6 +34,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \
Networking type. Current options are: 'socket'. Networking type. Current options are: 'socket'.
""" """
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_client assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_client
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % queue_size assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type
...@@ -51,7 +54,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \ ...@@ -51,7 +54,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.ClientBarrierRequest, rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse) rpc.ClientBarrierResponse)
rpc.set_rank(server_id) rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config) server_namebook = rpc.read_ip_config(ip_config, num_servers)
machine_id = server_namebook[server_id][0] machine_id = server_namebook[server_id][0]
rpc.set_machine_id(machine_id) rpc.set_machine_id(machine_id)
ip_addr = server_namebook[server_id][1] ip_addr = server_namebook[server_id][1]
......
...@@ -53,8 +53,8 @@ def create_random_graph(n): ...@@ -53,8 +53,8 @@ def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64) arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
return dgl.graph(arr) return dgl.graph(arr)
def run_server(graph_name, server_id, num_clients, shared_mem): def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, server_count,
'/tmp/dist_graph/{}.json'.format(graph_name), '/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem) disable_shared_mem=not shared_mem)
print('start server', server_id) print('start server', server_id)
...@@ -66,9 +66,9 @@ def emb_init(shape, dtype): ...@@ -66,9 +66,9 @@ def emb_init(shape, dtype):
def rand_init(shape, dtype): def rand_init(shape, dtype):
return F.tensor(np.random.normal(size=shape), F.float32) return F.tensor(np.random.normal(size=shape), F.float32)
def run_client(graph_name, part_id, num_clients, num_nodes, num_edges): def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5) time.sleep(5)
dgl.distributed.initialize("kv_ip_config.txt") dgl.distributed.initialize("kv_ip_config.txt", server_count)
gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name), gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None) part_id, None)
g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb) g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
...@@ -193,7 +193,7 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges): ...@@ -193,7 +193,7 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
print('end') print('end')
def check_server_client(shared_mem, num_servers, num_clients): def check_server_client(shared_mem, num_servers, num_clients):
prepare_dist(num_servers) prepare_dist()
g = create_random_graph(10000) g = create_random_graph(10000)
# Partition the graph # Partition the graph
...@@ -208,7 +208,7 @@ def check_server_client(shared_mem, num_servers, num_clients): ...@@ -208,7 +208,7 @@ def check_server_client(shared_mem, num_servers, num_clients):
serv_ps = [] serv_ps = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
for serv_id in range(num_servers): for serv_id in range(num_servers):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
num_clients, shared_mem)) num_clients, shared_mem))
serv_ps.append(p) serv_ps.append(p)
p.start() p.start()
...@@ -216,7 +216,7 @@ def check_server_client(shared_mem, num_servers, num_clients): ...@@ -216,7 +216,7 @@ def check_server_client(shared_mem, num_servers, num_clients):
cli_ps = [] cli_ps = []
for cli_id in range(num_clients): for cli_id in range(num_clients):
print('start client', cli_id) print('start client', cli_id)
p = ctx.Process(target=run_client, args=(graph_name, 0, num_clients, g.number_of_nodes(), p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(),
g.number_of_edges())) g.number_of_edges()))
p.start() p.start()
cli_ps.append(p) cli_ps.append(p)
...@@ -380,10 +380,10 @@ def test_split_even(): ...@@ -380,10 +380,10 @@ def test_split_even():
assert np.all(all_nodes == F.asnumpy(all_nodes2)) assert np.all(all_nodes == F.asnumpy(all_nodes2))
assert np.all(all_edges == F.asnumpy(all_edges2)) assert np.all(all_edges == F.asnumpy(all_edges2))
def prepare_dist(num_servers): def prepare_dist():
ip_config = open("kv_ip_config.txt", "w") ip_config = open("kv_ip_config.txt", "w")
ip_addr = get_local_usable_addr() ip_addr = get_local_usable_addr()
ip_config.write('{} {}\n'.format(ip_addr, num_servers)) ip_config.write('{}\n'.format(ip_addr))
ip_config.close() ip_config.close()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,7 +16,7 @@ from dgl.distributed import DistGraphServer, DistGraph ...@@ -16,7 +16,7 @@ from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir, disable_shared_mem, graph_name): def start_server(rank, tmpdir, disable_shared_mem, graph_name):
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem) tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
g.start() g.start()
...@@ -25,7 +25,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem): ...@@ -25,7 +25,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank) _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt", 1)
dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb) dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3) sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dgl.distributed.exit_client() dgl.distributed.exit_client()
...@@ -35,7 +35,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids): ...@@ -35,7 +35,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_find_edges.json', rank) _, _, _, gpb, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt", 1)
dist_graph = DistGraph("rpc_ip_config.txt", "test_find_edges", gpb=gpb) dist_graph = DistGraph("rpc_ip_config.txt", "test_find_edges", gpb=gpb)
u, v = find_edges(dist_graph, eids) u, v = find_edges(dist_graph, eids)
dgl.distributed.exit_client() dgl.distributed.exit_client()
...@@ -44,7 +44,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids): ...@@ -44,7 +44,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
def check_rpc_sampling(tmpdir, num_server): def check_rpc_sampling(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w") ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server): for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
...@@ -80,7 +80,7 @@ def check_rpc_sampling(tmpdir, num_server): ...@@ -80,7 +80,7 @@ def check_rpc_sampling(tmpdir, num_server):
def check_rpc_find_edges(tmpdir, num_server): def check_rpc_find_edges(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w") ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server): for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
...@@ -116,7 +116,7 @@ def test_rpc_sampling(): ...@@ -116,7 +116,7 @@ def test_rpc_sampling():
def check_rpc_sampling_shuffle(tmpdir, num_server): def check_rpc_sampling_shuffle(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w") ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server): for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
...@@ -175,7 +175,7 @@ def check_standalone_sampling(tmpdir): ...@@ -175,7 +175,7 @@ def check_standalone_sampling(tmpdir):
num_hops=num_hops, part_method='metis', reshuffle=False) num_hops=num_hops, part_method='metis', reshuffle=False)
os.environ['DGL_DIST_MODE'] = 'standalone' os.environ['DGL_DIST_MODE'] = 'standalone'
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt", 1)
dist_graph = DistGraph(None, "test_sampling", part_config=tmpdir / 'test_sampling.json') dist_graph = DistGraph(None, "test_sampling", part_config=tmpdir / 'test_sampling.json')
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3) sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
...@@ -197,7 +197,7 @@ def test_standalone_sampling(): ...@@ -197,7 +197,7 @@ def test_standalone_sampling():
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes): def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
gpb = None gpb = None
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt", 1)
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank) _, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb) dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
...@@ -209,7 +209,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes): ...@@ -209,7 +209,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
def check_rpc_in_subgraph(tmpdir, num_server): def check_rpc_in_subgraph(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w") ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server): for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
......
...@@ -40,7 +40,7 @@ class NeighborSampler(object): ...@@ -40,7 +40,7 @@ class NeighborSampler(object):
def start_server(rank, tmpdir, disable_shared_mem, num_clients): def start_server(rank, tmpdir, disable_shared_mem, num_clients):
import dgl import dgl
print('server: #clients=' + str(num_clients)) print('server: #clients=' + str(num_clients))
g = DistGraphServer(rank, "mp_ip_config.txt", num_clients, g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem) tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
g.start() g.start()
...@@ -48,7 +48,7 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients): ...@@ -48,7 +48,7 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last): def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
import dgl import dgl
import torch as th import torch as th
dgl.distributed.initialize("mp_ip_config.txt", num_workers=num_workers) dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank) _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
...@@ -95,7 +95,7 @@ def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last): ...@@ -95,7 +95,7 @@ def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
def test_standalone(tmpdir): def test_standalone(tmpdir):
ip_config = open("mp_ip_config.txt", "w") ip_config = open("mp_ip_config.txt", "w")
for _ in range(1): for _ in range(1):
ip_config.write('{} 1\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
...@@ -119,7 +119,7 @@ def test_standalone(tmpdir): ...@@ -119,7 +119,7 @@ def test_standalone(tmpdir):
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last): def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last):
ip_config = open("mp_ip_config.txt", "w") ip_config = open("mp_ip_config.txt", "w")
for _ in range(num_server): for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
......
...@@ -103,12 +103,13 @@ def test_partition_policy(): ...@@ -103,12 +103,13 @@ def test_partition_policy():
assert node_policy.get_data_size() == len(node_map) assert node_policy.get_data_size() == len(node_map)
assert edge_policy.get_data_size() == len(edge_map) assert edge_policy.get_data_size() == len(edge_map)
def start_server(server_id, num_clients): def start_server(server_id, num_clients, num_servers):
# Init kvserver # Init kvserver
print("Sleep 5 seconds to test client re-connect.") print("Sleep 5 seconds to test client re-connect.")
time.sleep(5) time.sleep(5)
kvserver = dgl.distributed.KVServer(server_id=server_id, kvserver = dgl.distributed.KVServer(server_id=server_id,
ip_config='kv_ip_config.txt', ip_config='kv_ip_config.txt',
num_servers=num_servers,
num_clients=num_clients) num_clients=num_clients)
kvserver.add_part_policy(node_policy) kvserver.add_part_policy(node_policy)
kvserver.add_part_policy(edge_policy) kvserver.add_part_policy(edge_policy)
...@@ -126,13 +127,15 @@ def start_server(server_id, num_clients): ...@@ -126,13 +127,15 @@ def start_server(server_id, num_clients):
server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None) server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
dgl.distributed.start_server(server_id=server_id, dgl.distributed.start_server(server_id=server_id,
ip_config='kv_ip_config.txt', ip_config='kv_ip_config.txt',
num_servers=num_servers,
num_clients=num_clients, num_clients=num_clients,
server_state=server_state) server_state=server_state)
def start_server_mul_role(server_id, num_clients): def start_server_mul_role(server_id, num_clients, num_servers):
# Init kvserver # Init kvserver
kvserver = dgl.distributed.KVServer(server_id=server_id, kvserver = dgl.distributed.KVServer(server_id=server_id,
ip_config='kv_ip_mul_config.txt', ip_config='kv_ip_mul_config.txt',
num_servers=num_servers,
num_clients=num_clients) num_clients=num_clients)
kvserver.add_part_policy(node_policy) kvserver.add_part_policy(node_policy)
if kvserver.is_backup_server(): if kvserver.is_backup_server():
...@@ -143,15 +146,16 @@ def start_server_mul_role(server_id, num_clients): ...@@ -143,15 +146,16 @@ def start_server_mul_role(server_id, num_clients):
server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None) server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
dgl.distributed.start_server(server_id=server_id, dgl.distributed.start_server(server_id=server_id,
ip_config='kv_ip_mul_config.txt', ip_config='kv_ip_mul_config.txt',
num_servers=num_servers,
num_clients=num_clients, num_clients=num_clients,
server_state=server_state) server_state=server_state)
def start_client(num_clients): def start_client(num_clients, num_servers):
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
# Note: connect to server first ! # Note: connect to server first !
dgl.distributed.initialize(ip_config='kv_ip_config.txt') dgl.distributed.initialize(ip_config='kv_ip_config.txt', num_servers=num_servers)
# Init kvclient # Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt') kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt', num_servers=num_servers)
kvclient.map_shared_data(partition_book=gpb) kvclient.map_shared_data(partition_book=gpb)
assert dgl.distributed.get_num_client() == num_clients assert dgl.distributed.get_num_client() == num_clients
kvclient.init_data(name='data_1', kvclient.init_data(name='data_1',
...@@ -276,10 +280,10 @@ def start_client(num_clients): ...@@ -276,10 +280,10 @@ def start_client(num_clients):
data_tensor = data_tensor * num_clients data_tensor = data_tensor * num_clients
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor)) assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
def start_client_mul_role(i, num_workers): def start_client_mul_role(i, num_workers, num_servers):
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
# Initialize creates kvstore ! # Initialize creates kvstore !
dgl.distributed.initialize(ip_config='kv_ip_mul_config.txt', num_workers=num_workers) dgl.distributed.initialize(ip_config='kv_ip_mul_config.txt', num_servers=num_servers, num_workers=num_workers)
if i == 0: # block one trainer if i == 0: # block one trainer
time.sleep(5) time.sleep(5)
kvclient = dgl.distributed.kvstore.get_kvstore() kvclient = dgl.distributed.kvstore.get_kvstore()
...@@ -298,17 +302,17 @@ def test_kv_store(): ...@@ -298,17 +302,17 @@ def test_kv_store():
num_servers = 2 num_servers = 2
num_clients = 2 num_clients = 2
ip_addr = get_local_usable_addr() ip_addr = get_local_usable_addr()
ip_config.write('{} {}\n'.format(ip_addr, num_servers)) ip_config.write('{}\n'.format(ip_addr))
ip_config.close() ip_config.close()
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver_list = [] pserver_list = []
pclient_list = [] pclient_list = []
for i in range(num_servers): for i in range(num_servers):
pserver = ctx.Process(target=start_server, args=(i, num_clients)) pserver = ctx.Process(target=start_server, args=(i, num_clients, num_servers))
pserver.start() pserver.start()
pserver_list.append(pserver) pserver_list.append(pserver)
for i in range(num_clients): for i in range(num_clients):
pclient = ctx.Process(target=start_client, args=(num_clients,)) pclient = ctx.Process(target=start_client, args=(num_clients, num_servers))
pclient.start() pclient.start()
pclient_list.append(pclient) pclient_list.append(pclient)
for i in range(num_clients): for i in range(num_clients):
...@@ -325,17 +329,17 @@ def test_kv_multi_role(): ...@@ -325,17 +329,17 @@ def test_kv_multi_role():
# There are two trainer processes and each trainer process has two sampler processes. # There are two trainer processes and each trainer process has two sampler processes.
num_clients = num_trainers * (1 + num_samplers) num_clients = num_trainers * (1 + num_samplers)
ip_addr = get_local_usable_addr() ip_addr = get_local_usable_addr()
ip_config.write('{} {}\n'.format(ip_addr, num_servers)) ip_config.write('{}\n'.format(ip_addr))
ip_config.close() ip_config.close()
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver_list = [] pserver_list = []
pclient_list = [] pclient_list = []
for i in range(num_servers): for i in range(num_servers):
pserver = ctx.Process(target=start_server_mul_role, args=(i, num_clients)) pserver = ctx.Process(target=start_server_mul_role, args=(i, num_clients, num_servers))
pserver.start() pserver.start()
pserver_list.append(pserver) pserver_list.append(pserver)
for i in range(num_trainers): for i in range(num_trainers):
pclient = ctx.Process(target=start_client_mul_role, args=(i, num_samplers)) pclient = ctx.Process(target=start_client_mul_role, args=(i, num_samplers, num_servers))
pclient.start() pclient.start()
pclient_list.append(pclient) pclient_list.append(pclient)
for i in range(num_trainers): for i in range(num_trainers):
......
...@@ -114,12 +114,13 @@ def start_server(num_clients, ip_config): ...@@ -114,12 +114,13 @@ def start_server(num_clients, ip_config):
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse) dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.start_server(server_id=0, dgl.distributed.start_server(server_id=0,
ip_config=ip_config, ip_config=ip_config,
num_servers=1,
num_clients=num_clients, num_clients=num_clients,
server_state=server_state) server_state=server_state)
def start_client(ip_config): def start_client(ip_config):
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse) dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.connect_to_server(ip_config=ip_config) dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
req = HelloRequest(STR, INTEGER, TENSOR, simple_func) req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv # test send and recv
dgl.distributed.send_request(0, req) dgl.distributed.send_request(0, req)
...@@ -191,7 +192,7 @@ def test_rpc(): ...@@ -191,7 +192,7 @@ def test_rpc():
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = open("rpc_ip_config.txt", "w") ip_config = open("rpc_ip_config.txt", "w")
ip_addr = get_local_usable_addr() ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr) ip_config.write('%s\n' % ip_addr)
ip_config.close() ip_config.close()
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt")) pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
...@@ -207,7 +208,7 @@ def test_multi_client(): ...@@ -207,7 +208,7 @@ def test_multi_client():
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = open("rpc_ip_config_mul_client.txt", "w") ip_config = open("rpc_ip_config_mul_client.txt", "w")
ip_addr = get_local_usable_addr() ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr) ip_config.write('%s\n' % ip_addr)
ip_config.close() ip_config.close()
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server, args=(10, "rpc_ip_config_mul_client.txt")) pserver = ctx.Process(target=start_server, args=(10, "rpc_ip_config_mul_client.txt"))
......
...@@ -36,10 +36,10 @@ def main(): ...@@ -36,10 +36,10 @@ def main():
hosts = [] hosts = []
with open(args.ip_config) as f: with open(args.ip_config) as f:
for line in f: for line in f:
ip, _, _ = line.strip().split(' ') res = line.strip().split(' ')
ip = res[0]
hosts.append(ip) hosts.append(ip)
# We need to update the partition config file so that the paths are relative to # We need to update the partition config file so that the paths are relative to
# the workspace in the remote machines. # the workspace in the remote machines.
with open(args.part_config) as conf_f: with open(args.part_config) as conf_f:
......
...@@ -10,9 +10,11 @@ import time ...@@ -10,9 +10,11 @@ import time
import json import json
from threading import Thread from threading import Thread
def execute_remote(cmd, ip, thread_list): DEFAULT_PORT = 30050
def execute_remote(cmd, ip, port, thread_list):
"""execute command line on remote machine via ssh""" """execute command line on remote machine via ssh"""
cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\'' cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'' + cmd + '\''
# thread func to run the job # thread func to run the job
def run(cmd): def run(cmd):
subprocess.check_call(cmd, shell = True) subprocess.check_call(cmd, shell = True)
...@@ -32,12 +34,18 @@ def submit_jobs(args, udf_command): ...@@ -32,12 +34,18 @@ def submit_jobs(args, udf_command):
ip_config = args.workspace + '/' + args.ip_config ip_config = args.workspace + '/' + args.ip_config
with open(ip_config) as f: with open(ip_config) as f:
for line in f: for line in f:
ip, port, count = line.strip().split(' ') result = line.strip().split()
port = int(port) if len(result) == 2:
count = int(count) ip = result[0]
server_count_per_machine = count port = int(result[1])
hosts.append((ip, port)) hosts.append((ip, port))
elif len(result) == 1:
ip = result[0]
port = DEFAULT_PORT
hosts.append((ip, port))
else:
raise RuntimeError("Format error of ip_config.")
server_count_per_machine = args.num_servers
# Get partition info of the graph data # Get partition info of the graph data
part_config = args.workspace + '/' + args.part_config part_config = args.workspace + '/' + args.part_config
with open(part_config) as conf_f: with open(part_config) as conf_f:
...@@ -54,17 +62,19 @@ def submit_jobs(args, udf_command): ...@@ -54,17 +62,19 @@ def submit_jobs(args, udf_command):
server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients) server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config) server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
server_cmd = server_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
for i in range(len(hosts)*server_count_per_machine): for i in range(len(hosts)*server_count_per_machine):
ip, _ = hosts[int(i / server_count_per_machine)] ip, _ = hosts[int(i / server_count_per_machine)]
cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i) cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
cmd = cmd + ' ' + udf_command cmd = cmd + ' ' + udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, thread_list) execute_remote(cmd, ip, args.ssh_port, thread_list)
# launch client tasks # launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client' client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client'
client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients) client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config) client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
client_cmd = client_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
if os.environ.get('OMP_NUM_THREADS') is not None: if os.environ.get('OMP_NUM_THREADS') is not None:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS') client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS')
if os.environ.get('PYTHONPATH') is not None: if os.environ.get('PYTHONPATH') is not None:
...@@ -87,13 +97,14 @@ def submit_jobs(args, udf_command): ...@@ -87,13 +97,14 @@ def submit_jobs(args, udf_command):
new_udf_command = udf_command.replace('python', 'python ' + new_torch_cmd) new_udf_command = udf_command.replace('python', 'python ' + new_torch_cmd)
cmd = client_cmd + ' ' + new_udf_command cmd = client_cmd + ' ' + new_udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, thread_list) execute_remote(cmd, ip, args.ssh_port, thread_list)
for thread in thread_list: for thread in thread_list:
thread.join() thread.join()
def main(): def main():
parser = argparse.ArgumentParser(description='Launch a distributed job') parser = argparse.ArgumentParser(description='Launch a distributed job')
parser.add_argument('--ssh_port', type=int, default=22, help='SSH Port.')
parser.add_argument('--workspace', type=str, parser.add_argument('--workspace', type=str,
help='Path of user directory of distributed tasks. \ help='Path of user directory of distributed tasks. \
This is used to specify a destination location where \ This is used to specify a destination location where \
...@@ -106,6 +117,8 @@ def main(): ...@@ -106,6 +117,8 @@ def main():
help='The file (in workspace) of the partition config') help='The file (in workspace) of the partition config')
parser.add_argument('--ip_config', type=str, parser.add_argument('--ip_config', type=str,
help='The file (in workspace) of IP configuration for server processes') help='The file (in workspace) of IP configuration for server processes')
parser.add_argument('--num_servers', type=int,
help='Server count on each machine.')
parser.add_argument('--num_server_threads', type=int, default=1, parser.add_argument('--num_server_threads', type=int, default=1,
help='The number of OMP threads in the server process. \ help='The number of OMP threads in the server process. \
It should be small if server processes and trainer processes run on \ It should be small if server processes and trainer processes run on \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment