Unverified commit 02e4cd8b authored by Rhett Ying, committed by GitHub

[Feature] long live server for multiple client groups (#3645)

* [Feature] long live server for multiple client groups

* generate globally unique name for DistTensor within DGL automatically
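For context, here is a hedged sketch of the intended workflow (a hypothetical driver script, not part of this commit; it only uses the knobs this change touches: `DGL_KEEP_ALIVE`, `DGL_GROUP_ID`, `dgl.distributed.initialize`, and the reworked `dgl.distributed.shutdown_servers`):

```python
# Hypothetical end-to-end sketch of the keep-alive workflow introduced here.
# The ip_config path, graph name, and the DGL_ROLE check are placeholders.
import os
import dgl

if os.environ.get('DGL_ROLE') == 'server':
    # Server side: setting DGL_KEEP_ALIVE (any value) makes the server survive
    # client-group exits; it only stops on a force-shutdown request.
    os.environ['DGL_KEEP_ALIVE'] = '1'
    dgl.distributed.initialize('ip_config.txt')  # boots DistGraphServer, never returns
else:
    # Client side: every co-launched group of clients shares one DGL_GROUP_ID.
    os.environ['DGL_GROUP_ID'] = '0'  # a second launch would use '1', and so on
    dgl.distributed.initialize('ip_config.txt')
    g = dgl.distributed.DistGraph('graph_name')
    # ... training ...
    # exiting this process leaves the keep-alive servers running

# After the last group has finished, shut the servers down explicitly:
# dgl.distributed.shutdown_servers('ip_config.txt', num_servers=1)
```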
parent 2b98e764
@@ -23,7 +23,7 @@
 from . import optim
 from .rpc import *
 from .rpc_server import start_server
-from .rpc_client import connect_to_server
+from .rpc_client import connect_to_server, shutdown_servers
 from .dist_context import initialize, exit_client
 from .kvstore import KVServer, KVClient
 from .server_state import ServerState
...
@@ -2,3 +2,6 @@
 # Maximum size of message queue in bytes
 MAX_QUEUE_SIZE = 20*1024*1024*1024
+
+SERVER_EXIT = "server_exit"
+SERVER_KEEP_ALIVE = "server_keep_alive"
...
@@ -8,12 +8,13 @@
 import time
 import os
 import sys
 import queue
+import gc
 
 from enum import Enum
 from . import rpc
 from .constants import MAX_QUEUE_SIZE
 from .kvstore import init_kvstore, close_kvstore
-from .rpc_client import connect_to_server, shutdown_servers
+from .rpc_client import connect_to_server
 from .role import init_role
 from .. import utils
@@ -33,13 +34,13 @@ def get_sampler_pool():
     return SAMPLER_POOL, NUM_SAMPLER_WORKERS
 
-def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads):
+def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads, group_id):
    ''' This init function is called in the worker processes.
    '''
    try:
        utils.set_num_threads(num_threads)
        if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
-            connect_to_server(ip_config, num_servers, max_queue_size, net_type)
+            connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id)
            init_role(role)
            init_kvstore(ip_config, num_servers, role)
    except Exception as e:
@@ -227,12 +228,14 @@ def initialize(ip_config, num_servers=1, num_workers=0,
        formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
        formats = [f.strip() for f in formats]
        rpc.reset()
+        keep_alive = os.environ.get('DGL_KEEP_ALIVE') is not None
        serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
                               os.environ.get('DGL_IP_CONFIG'),
                               int(os.environ.get('DGL_NUM_SERVER')),
                               int(os.environ.get('DGL_NUM_CLIENT')),
                               os.environ.get('DGL_CONF_PATH'),
-                               graph_format=formats)
+                               graph_format=formats,
+                               keep_alive=keep_alive)
        serv.start()
        sys.exit()
    else:
@@ -244,7 +247,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
        num_servers = int(os.environ.get('DGL_NUM_SERVER'))
    else:
        num_servers = 1
+    group_id = int(os.environ.get('DGL_GROUP_ID', 0))
    rpc.reset()
    global SAMPLER_POOL
    global NUM_SAMPLER_WORKERS
@@ -252,14 +255,15 @@ def initialize(ip_config, num_servers=1, num_workers=0,
        'DGL_DIST_MODE', 'standalone') == 'standalone'
    if num_workers > 0 and not is_standalone:
        SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size,
-                                                net_type, 'sampler', num_worker_threads))
+                                                net_type, 'sampler', num_worker_threads,
+                                                group_id))
    else:
        SAMPLER_POOL = None
    NUM_SAMPLER_WORKERS = num_workers
    if not is_standalone:
        assert num_servers is not None and num_servers > 0, \
            'The number of servers per machine must be specified with a positive number.'
-        connect_to_server(ip_config, num_servers, max_queue_size, net_type)
+        connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id=group_id)
        init_role('default')
        init_kvstore(ip_config, num_servers, 'default')
@@ -299,6 +303,14 @@ def is_initialized():
    return INITIALIZED
 
+def _shutdown_servers():
+    set_initialized(False)
+    # send ShutDownRequest to servers
+    if rpc.get_rank() == 0:  # Only client_0 issues this command
+        req = rpc.ShutDownRequest(rpc.get_rank())
+        for server_id in range(rpc.get_num_server()):
+            rpc.send_request(server_id, req)
 
 def exit_client():
     """Trainer exits
@@ -311,9 +323,11 @@ def exit_client():
     """
     # Only client with rank_0 will send shutdown request to servers.
     finalize_worker()  # finalizing workers should happen earlier than the barrier, and is non-blocking
+    # collect data such as DistTensor before exit
+    gc.collect()
     if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
         rpc.client_barrier()
-        shutdown_servers()
+        _shutdown_servers()
         finalize_client()
     join_finalize_worker()
     close_kvstore()
...
@@ -191,7 +191,7 @@ class NodeDataView(MutableMapping):
             dtype, shape, _ = g._client.get_data_meta(str(name))
             # We create a wrapper on the existing tensor in the kvstore.
             self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
-                                                     part_policy=policy)
+                                                     part_policy=policy, attach=False)
 
     def _get_names(self):
         return list(self._data.keys())
@@ -245,7 +245,7 @@ class EdgeDataView(MutableMapping):
             dtype, shape, _ = g._client.get_data_meta(str(name))
             # We create a wrapper on the existing tensor in the kvstore.
             self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
-                                                     part_policy=policy)
+                                                     part_policy=policy, attach=False)
 
     def _get_names(self):
         return list(self._data.keys())
@@ -308,16 +308,19 @@ class DistGraphServer(KVServer):
         Disable shared memory.
     graph_format : str or list of str
         The graph formats.
+    keep_alive : bool
+        Whether to keep the server alive when clients exit.
     '''
     def __init__(self, server_id, ip_config, num_servers,
                  num_clients, part_config, disable_shared_mem=False,
-                 graph_format=('csc', 'coo')):
+                 graph_format=('csc', 'coo'), keep_alive=False):
         super(DistGraphServer, self).__init__(server_id=server_id,
                                               ip_config=ip_config,
                                               num_servers=num_servers,
                                               num_clients=num_clients)
         self.ip_config = ip_config
         self.num_servers = num_servers
+        self.keep_alive = keep_alive
         # Load graph partition data.
         if self.is_backup_server():
             # The backup server doesn't load the graph partition. It'll be initialized afterwards.
@@ -351,6 +354,7 @@ class DistGraphServer(KVServer):
                 data_name = HeteroDataName(True, ntype, feat_name)
                 self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                                data_tensor=node_feats[name])
+                self.orig_data.add(str(data_name))
         for name in edge_feats:
             # The feature name has the following format: edge_type + "/" + feature_name to avoid
             # feature name collision for different edge types.
@@ -358,13 +362,16 @@ class DistGraphServer(KVServer):
                 data_name = HeteroDataName(False, etype, feat_name)
                 self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                                data_tensor=edge_feats[name])
+                self.orig_data.add(str(data_name))
 
     def start(self):
         """ Start graph store server.
         """
         # start server
-        server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
-        print('start graph service on server {} for part {}'.format(self.server_id, self.part_id))
+        server_state = ServerState(kv_store=self, local_g=self.client_g,
+                                   partition_book=self.gpb, keep_alive=self.keep_alive)
+        print('start graph service on server {} for part {}'.format(
+            self.server_id, self.part_id))
         start_server(server_id=self.server_id,
                      ip_config=self.ip_config,
                      num_servers=self.num_servers,
...
@@ -7,6 +7,7 @@
 from .kvstore import get_kvstore
 from .role import get_role
 from .. import utils
 from .. import backend as F
+from .rpc import get_group_id
 
 def _default_init_data(shape, dtype):
     return F.zeros(shape, dtype, F.cpu())
@@ -80,6 +81,8 @@ class DistTensor:
         Whether the created tensor lives after the ``DistTensor`` object is destroyed.
     is_gdata : bool
         Whether the created tensor is a ndata/edata or not.
+    attach : bool
+        Whether to attach the group ID to the name to make it globally unique.
 
     Examples
     --------
@@ -102,12 +105,13 @@ class DistTensor:
     do the same.
     '''
     def __init__(self, shape, dtype, name=None, init_func=None, part_policy=None,
-                 persistent=False, is_gdata=True):
+                 persistent=False, is_gdata=True, attach=True):
         self.kvstore = get_kvstore()
         assert self.kvstore is not None, \
             'Distributed module is not initialized. Please call dgl.distributed.initialize.'
         self._shape = shape
         self._dtype = dtype
+        self._attach = attach
 
         part_policies = self.kvstore.all_possible_part_policy
         # If a user doesn't provide a partition policy, we should find one based on
@@ -128,7 +132,6 @@
             + 'its first dimension does not match the number of nodes or edges ' \
             + 'of a distributed graph or there does not exist a distributed graph.'
 
-        self._tensor_name = name
         self._part_policy = part_policy
         assert part_policy.get_size() == shape[0], \
             'The partition policy does not match the input shape.'
@@ -146,6 +149,8 @@
             name = 'anonymous-' + get_role() + '-' + str(DIST_TENSOR_ID)
             DIST_TENSOR_ID += 1
         assert isinstance(name, str), 'name {} is type {}'.format(name, type(name))
+        name = self._attach_group_id(name)
+        self._tensor_name = name
         data_name = part_policy.get_data_name(name)
         self._name = str(data_name)
         self._persistent = persistent
@@ -220,7 +225,7 @@
         str
             The name of the tensor.
         '''
-        return self._name
+        return self._detach_group_id(self._name)
 
     @property
     def tensor_name(self):
@@ -231,7 +236,7 @@
         str
             The name of the tensor.
         '''
-        return self._tensor_name
+        return self._detach_group_id(self._tensor_name)
 
     def count_nonzero(self):
         '''Count and return the number of nonzero value
@@ -241,4 +246,29 @@
         int
             the number of nonzero value
         '''
-        return self.kvstore.count_nonzero(name=self.name)
+        return self.kvstore.count_nonzero(name=self._name)
+
+    def _attach_group_id(self, name):
+        """Attach the group ID if needed
+
+        Returns
+        -------
+        str
+            new name with group ID attached
+        """
+        if not self._attach:
+            return name
+        return "{}_{}".format(name, get_group_id())
+
+    def _detach_group_id(self, name):
+        """Detach the group ID if needed
+
+        Returns
+        -------
+        str
+            original name without group ID
+        """
+        if not self._attach:
+            return name
+        suffix = "_{}".format(get_group_id())
+        return name[:-len(suffix)]
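A minimal, standalone sketch of the naming round-trip these two helpers implement (the constant `GROUP_ID` stands in for `rpc.get_group_id()`, which is process-wide state in the real code): the kvstore-facing name carries a `_<group_id>` suffix so tensors created by different client groups cannot collide on a keep-alive server, while the user-facing `name`/`tensor_name` properties strip the suffix again.

```python
# Standalone model of _attach_group_id/_detach_group_id.
GROUP_ID = 1  # stand-in for rpc.get_group_id()

def attach_group_id(name, attach=True):
    return "{}_{}".format(name, GROUP_ID) if attach else name

def detach_group_id(name, attach=True):
    if not attach:
        return name
    suffix = "_{}".format(GROUP_ID)
    return name[:-len(suffix)]

assert attach_group_id("emb") == "emb_1"                  # globally unique in kvstore
assert detach_group_id(attach_group_id("emb")) == "emb"   # the user still sees "emb"
# ndata/edata wrappers pass attach=False, so existing graph features
# keep their original, group-independent names:
assert attach_group_id("feat", attach=False) == "feat"
```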
...
@@ -206,20 +206,23 @@ class BarrierRequest(rpc.Request):
     """
     def __init__(self, role):
         self.role = role
+        self.group_id = rpc.get_group_id()
 
     def __getstate__(self):
-        return self.role
+        return self.role, self.group_id
 
     def __setstate__(self, state):
-        self.role = state
+        self.role, self.group_id = state
 
     def process_request(self, server_state):
         kv_store = server_state.kv_store
-        role = server_state.roles
-        count = kv_store.barrier_count[self.role]
-        kv_store.barrier_count[self.role] = count + 1
-        if kv_store.barrier_count[self.role] == len(role[self.role]):
-            kv_store.barrier_count[self.role] = 0
+        roles = server_state.roles
+        role = roles[self.group_id]
+        barrier_count = kv_store.barrier_count[self.group_id]
+        count = barrier_count[self.role]
+        barrier_count[self.role] = count + 1
+        if barrier_count[self.role] == len(role[self.role]):
+            barrier_count[self.role] = 0
             res_list = []
             for client_id, _ in role[self.role]:
                 res_list.append((client_id, BarrierResponse(BARRIER_MSG)))
@@ -362,6 +365,9 @@ class GetSharedDataRequest(rpc.Request):
         meta = {}
         kv_store = server_state.kv_store
         for name, data in kv_store.data_store.items():
+            if server_state.keep_alive:
+                if name not in kv_store.orig_data:
+                    continue
             meta[name] = (F.shape(data),
                           F.reverse_data_type_dict[F.dtype(data)],
                           kv_store.part_policy[name].policy_str)
@@ -671,6 +677,8 @@ class KVServer(object):
                               CountLocalNonzeroResponse)
         # Store the tensor data with specified data name
         self._data_store = {}
+        # Store original tensor data names when instantiating DistGraphServer
+        self._orig_data = set()
         # Store the partition information with specified data name
         self._policy_set = set()
         self._part_policy = {}
@@ -715,6 +723,11 @@ class KVServer(object):
         """Get data store"""
         return self._data_store
 
+    @property
+    def orig_data(self):
+        """Get original data"""
+        return self._orig_data
+
     @property
     def part_policy(self):
         """Get part policy"""
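A toy model of the `orig_data` filter added to `GetSharedDataRequest` above (the names and values are illustrative, not from this diff): a keep-alive server only advertises tensors that were loaded when `DistGraphServer` was instantiated, so `DistTensor`s created by an earlier client group are not handed out to the next one.

```python
# Toy model of the keep-alive metadata filter; data_store/orig_data mirror
# the KVServer fields, the stored values are placeholders.
data_store = {
    'feat': 'graph feature loaded at server start',
    'anonymous-default-0_0': 'DistTensor created later by client group 0',
}
orig_data = {'feat'}

def shared_data_meta(keep_alive):
    meta = {}
    for name in data_store:
        if keep_alive and name not in orig_data:
            continue  # hide tensors a previous client group created
        meta[name] = data_store[name]
    return meta

assert 'anonymous-default-0_0' not in shared_data_meta(keep_alive=True)
assert 'anonymous-default-0_0' in shared_data_meta(keep_alive=False)
```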
...
@@ -39,20 +39,22 @@ class RegisterRoleRequest(rpc.Request):
         self.client_id = client_id
         self.machine_id = machine_id
         self.role = role
+        self.group_id = rpc.get_group_id()
 
     def __getstate__(self):
-        return self.client_id, self.machine_id, self.role
+        return self.client_id, self.machine_id, self.role, self.group_id
 
     def __setstate__(self, state):
-        self.client_id, self.machine_id, self.role = state
+        self.client_id, self.machine_id, self.role, self.group_id = state
 
     def process_request(self, server_state):
         kv_store = server_state.kv_store
-        role = server_state.roles
+        role = server_state.roles.setdefault(self.group_id, {})
         if self.role not in role:
             role[self.role] = set()
             if kv_store is not None:
-                kv_store.barrier_count[self.role] = 0
+                barrier_count = kv_store.barrier_count.setdefault(self.group_id, {})
+                barrier_count[self.role] = 0
         role[self.role].add((self.client_id, self.machine_id))
         total_count = 0
         for key in role:
@@ -84,15 +86,16 @@ class GetRoleRequest(rpc.Request):
     """Send a request to get the roles of all client processes."""
     def __init__(self):
         self.msg = GET_ROLE_MSG
+        self.group_id = rpc.get_group_id()
 
     def __getstate__(self):
-        return self.msg
+        return self.msg, self.group_id
 
     def __setstate__(self, state):
-        self.msg = state
+        self.msg, self.group_id = state
 
     def process_request(self, server_state):
-        return GetRoleResponse(server_state.roles)
+        return GetRoleResponse(server_state.roles[self.group_id])
 
 # The key is role, the value is a dict of mapping RPC rank to a rank within the role.
 PER_ROLE_RANK = {}
...
@@ -6,6 +6,8 @@
 import pickle
 import random
 import numpy as np
 
+from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
 from .._ffi.object import register_object, ObjectBase
 from .._ffi.function import _init_api
 from ..base import DGLError
@@ -156,7 +158,7 @@ def receiver_wait(ip_addr, port, num_senders, blocking=True):
     """
     _CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)
 
-def connect_receiver(ip_addr, port, recv_id):
+def connect_receiver(ip_addr, port, recv_id, group_id=-1):
     """Connect to target receiver
 
     Parameters
@@ -168,7 +170,10 @@ def connect_receiver(ip_addr, port, recv_id):
     recv_id : int
         receiver's ID
     """
-    return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(recv_id))
+    target_id = recv_id if group_id == -1 else register_client(recv_id, group_id)
+    if target_id < 0:
+        raise DGLError("Invalid target id: {}".format(target_id))
+    return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
 
 def set_rank(rank):
     """Set the rank of this process.
@@ -497,8 +502,10 @@ class RPCMessage(ObjectBase):
         Payload buffer carried by this request.
     tensors : list[tensor]
         Extra payloads in the form of tensors.
+    group_id : int
+        The group ID.
     """
-    def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors):
+    def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors, group_id=0):
         self.__init_handle_by_constructor__(
             _CAPI_DGLRPCCreateRPCMessage,
             int(service_id),
@@ -506,7 +513,8 @@ class RPCMessage(ObjectBase):
             int(client_id),
             int(server_id),
             data,
-            [F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors])
+            [F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors],
+            int(group_id))
 
     @property
     def service_id(self):
@@ -539,6 +547,11 @@ class RPCMessage(ObjectBase):
         rst = _CAPI_DGLRPCMessageGetTensors(self)
         return [F.zerocopy_from_dgl_ndarray(tsor) for tsor in rst]
 
+    @property
+    def group_id(self):
+        """Get group ID."""
+        return _CAPI_DGLRPCMessageGetGroupId(self)
+
 def send_request(target, request):
     """Send one request to the target server.
@@ -566,7 +579,8 @@ def send_request(target, request):
     client_id = get_rank()
     server_id = target
     data, tensors = serialize_to_payload(request)
-    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
+    msg = RPCMessage(service_id, msg_seq, client_id, server_id,
+                     data, tensors, group_id=get_group_id())
     send_rpc_message(msg, server_id)
 
 def send_request_to_machine(target, request):
@@ -595,10 +609,10 @@ def send_request_to_machine(target, request):
     server_id = random.randint(target*get_num_server_per_machine(),
                                (target+1)*get_num_server_per_machine()-1)
     data, tensors = serialize_to_payload(request)
-    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
+    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
     send_rpc_message(msg, server_id)
 
-def send_response(target, response):
+def send_response(target, response, group_id):
     """Send one response to the target client.
 
     Serialize the given response object to an :class:`RPCMessage` and send it
@@ -615,6 +629,8 @@ def send_response(target, response):
         ID of target client.
     response : Response
         The response to send.
+    group_id : int
+        Group ID of target client.
 
     Raises
     ------
@@ -625,8 +641,8 @@
     client_id = target
     server_id = get_rank()
     data, tensors = serialize_to_payload(response)
-    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
-    send_rpc_message(msg, client_id)
+    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, group_id)
+    send_rpc_message(msg, get_client(client_id, group_id))
 
 def recv_request(timeout=0):
     """Receive one request.
@@ -647,6 +663,8 @@ def recv_request(timeout=0):
         One request received from the target, or None if it times out.
     client_id : int
         Client's ID received from the target.
+    group_id : int
+        Group ID received from the target.
 
     Raises
     ------
@@ -665,7 +683,7 @@
     if msg.server_id != get_rank():
         raise DGLError('Got request sent to server {}, '
                        'different from my rank {}!'.format(msg.server_id, get_rank()))
-    return req, msg.client_id
+    return req, msg.client_id, msg.group_id
 
 def recv_response(timeout=0):
     """Receive one response.
@@ -699,8 +717,11 @@
                        'but no response class is registered.'.format(msg.service_id))
     res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
     if msg.client_id != get_rank() and get_rank() != -1:
-        raise DGLError('Got reponse of request sent by client {}, '
+        raise DGLError('Got response of request sent by client {}, '
                        'different from my rank {}!'.format(msg.client_id, get_rank()))
+    if msg.group_id != get_group_id():
+        raise DGLError("Got response of request sent by group {}, "
+                       "different from my group {}!".format(msg.group_id, get_group_id()))
     return res
 
 def remote_call(target_and_requests, timeout=0):
@@ -742,7 +763,7 @@ def remote_call(target_and_requests, timeout=0):
         server_id = random.randint(target*get_num_server_per_machine(),
                                    (target+1)*get_num_server_per_machine()-1)
         data, tensors = serialize_to_payload(request)
-        msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
+        msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
         send_rpc_message(msg, server_id)
     # check if has response
     res_cls = get_service_property(service_id)[1]
@@ -792,7 +813,7 @@ def send_requests_to_machine(target_and_requests):
         server_id = random.randint(target*get_num_server_per_machine(),
                                    (target+1)*get_num_server_per_machine()-1)
         data, tensors = serialize_to_payload(request)
-        msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
+        msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
         send_rpc_message(msg, server_id)
     # check if has response
     res_cls = get_service_property(service_id)[1]
@@ -1050,19 +1071,22 @@ class ShutDownRequest(Request):
     client_id : int
         client's ID
     """
-    def __init__(self, client_id):
+    def __init__(self, client_id, force_shutdown_server=False):
         self.client_id = client_id
+        self.force_shutdown_server = force_shutdown_server
 
     def __getstate__(self):
-        return self.client_id
+        return self.client_id, self.force_shutdown_server
 
     def __setstate__(self, state):
-        self.client_id = state
+        self.client_id, self.force_shutdown_server = state
 
     def process_request(self, server_state):
         assert self.client_id == 0
+        if server_state.keep_alive and not self.force_shutdown_server:
+            return SERVER_KEEP_ALIVE
         finalize_server()
-        return 'exit'
+        return SERVER_EXIT
 
 GET_NUM_CLIENT = 22453
@@ -1133,21 +1157,69 @@ class ClientBarrierRequest(Request):
     """
     def __init__(self, msg='barrier'):
         self.msg = msg
+        self.group_id = get_group_id()
 
     def __getstate__(self):
-        return self.msg
+        return self.msg, self.group_id
 
     def __setstate__(self, state):
-        self.msg = state
+        self.msg, self.group_id = state
 
     def process_request(self, server_state):
-        _CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount()+1)
-        if _CAPI_DGLRPCGetBarrierCount() == get_num_client():
-            _CAPI_DGLRPCSetBarrierCount(0)
+        _CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount(self.group_id)+1, self.group_id)
+        if _CAPI_DGLRPCGetBarrierCount(self.group_id) == get_num_client():
+            _CAPI_DGLRPCSetBarrierCount(0, self.group_id)
             res_list = []
             for target_id in range(get_num_client()):
                 res_list.append((target_id, ClientBarrierResponse()))
             return res_list
         return None
 
+def set_group_id(group_id):
+    """Set current group ID
+
+    Parameters
+    ----------
+    group_id : int
+        Current group ID
+    """
+    _CAPI_DGLRPCSetGroupID(int(group_id))
+
+def get_group_id():
+    """Get current group ID
+
+    Returns
+    -------
+    int
+        group ID
+    """
+    return _CAPI_DGLRPCGetGroupID()
+
+def register_client(client_id, group_id):
+    """Register client
+
+    Returns
+    -------
+    int
+        unique client ID
+    """
+    return _CAPI_DGLRPCRegisterClient(int(client_id), int(group_id))
+
+def get_client(client_id, group_id):
+    """Get global client ID
+
+    Parameters
+    ----------
+    client_id : int
+        client ID
+    group_id : int
+        group ID
+
+    Returns
+    -------
+    int
+        global client ID
+    """
+    return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))
+
 _init_api("dgl.distributed.rpc")
...
@@ -103,7 +103,8 @@ def get_local_usable_addr(probe_addr):
     return ip_addr + ':' + str(port)
 
-def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
+def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
+                      net_type='socket', group_id=0):
     """Connect this client to server.
 
     Parameters
@@ -118,6 +119,10 @@
         it will not allocate 20GB memory at once.
     net_type : str
         Networking type. Current options are: 'socket'.
+    group_id : int
+        Indicates which group this client belongs to. Clients that are
+        booted together in each launch are gathered as one group and
+        should share the same unique group_id.
 
     Raises
     ------
@@ -156,6 +161,7 @@
     rpc.set_num_machines(num_machines)
     machine_id = get_local_machine_id(server_namebook)
     rpc.set_machine_id(machine_id)
+    rpc.set_group_id(group_id)
     rpc.create_sender(max_queue_size, net_type)
     rpc.create_receiver(max_queue_size, net_type)
     # Get connected with all server nodes
@@ -169,6 +175,7 @@
     client_ip, client_port = ip_addr.split(':')
     # wait server connect back
     rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
+    print("Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port))
     # Register client on server
     register_req = rpc.ClientRegisterRequest(ip_addr)
     for server_id in range(num_servers):
@@ -176,8 +183,8 @@
     # recv client ID from server
     res = rpc.recv_response()
     rpc.set_rank(res.client_id)
-    print("Machine (%d) client (%d) connect to server successfuly!" \
-          % (machine_id, rpc.get_rank()))
+    print("Machine (%d) group (%d) client (%d) connect to server successfully!" \
+          % (machine_id, group_id, rpc.get_rank()))
     # get total number of client
     get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())
     rpc.send_request(0, get_client_num_req)
@@ -187,16 +194,44 @@
     atexit.register(exit_client)
     set_initialized(True)
 
-def shutdown_servers():
+def shutdown_servers(ip_config, num_servers):
     """Issue commands to remote servers to shut them down.
+
+    This function needs to be called manually only when servers were booted
+    to keep alive even after clients exit. To shut the servers down
+    gracefully, we reuse the existing client logic to boot a special client
+    which does nothing but send a shutdown request to the servers. Once such
+    a request is received, servers exit from the endless wait loop, release
+    occupied resources and end their processes. Please call this function
+    with the same arguments used in `dgl.distributed.connect_to_server`.
+
+    Parameters
+    ----------
+    ip_config : str
+        Path of server IP configuration file.
+    num_servers : int
+        Server count on each machine.
 
     Raises
     ------
     ConnectionError : If anything wrong with the connection.
     """
-    from .dist_context import set_initialized
-    set_initialized(False)
-    if rpc.get_rank() == 0:  # Only client_0 issues this command
-        req = rpc.ShutDownRequest(rpc.get_rank())
-        for server_id in range(rpc.get_num_server()):
+    rpc.register_service(rpc.SHUT_DOWN_SERVER,
+                         rpc.ShutDownRequest,
+                         None)
+    rpc.register_sig_handler()
+    server_namebook = rpc.read_ip_config(ip_config, num_servers)
+    num_servers = len(server_namebook)
+    rpc.create_sender(MAX_QUEUE_SIZE, 'socket')
+    # Get connected with all server nodes
+    for server_id, addr in server_namebook.items():
+        server_ip = addr[1]
+        server_port = addr[2]
+        while not rpc.connect_receiver(server_ip, server_port, server_id):
+            time.sleep(1)
+    # send ShutDownRequest to all servers
+    req = rpc.ShutDownRequest(0, True)
+    for server_id in range(num_servers):
         rpc.send_request(server_id, req)
+    rpc.finalize_sender()
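A hedged usage note: once the last client group has exited, a separate one-off process (the test driver further below does exactly this) issues the force shutdown with the same arguments used for `connect_to_server`:

```python
# One-off shutdown driver for keep-alive servers; the ip_config path and
# server count are placeholders for whatever the deployment used.
import dgl

dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers=1)
```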
"""Functions used by server.""" """Functions used by server."""
import time import time
from ..base import DGLError
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'): max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
...@@ -37,6 +38,9 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -37,6 +38,9 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type
if server_state.keep_alive:
print("As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received.")
# Register signal handler. # Register signal handler.
rpc.register_sig_handler() rpc.register_sig_handler()
# Register some basic services # Register some basic services
...@@ -63,39 +67,52 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -63,39 +67,52 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# wait all the senders connect to server. # wait all the senders connect to server.
# Once all the senders connect to server, server will not # Once all the senders connect to server, server will not
# accept new sender's connection # accept new sender's connection
print("Wait connections non-blockingly...") print(
"Server is waiting for connections non-blockingly on [{}:{}]...".format(ip_addr, port))
rpc.receiver_wait(ip_addr, port, num_clients, blocking=False) rpc.receiver_wait(ip_addr, port, num_clients, blocking=False)
rpc.set_num_client(num_clients) rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients recv_clients = {}
addr_list = [] while True:
client_namebook = {} # go through if any client group is ready for connection
for _ in range(num_clients): for group_id in list(recv_clients.keys()):
# blocked until request is received ips = recv_clients[group_id]
req, _ = rpc.recv_request() if len(ips) < rpc.get_num_client():
assert isinstance(req, rpc.ClientRegisterRequest) continue
addr_list.append(req.ip_addr) else:
addr_list.sort() del recv_clients[group_id]
for client_id, addr in enumerate(addr_list): # a new client group is ready
client_namebook[client_id] = addr ips.sort()
client_namebook = {client_id:addr for client_id, addr in enumerate(ips)}
for client_id, addr in client_namebook.items(): for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':') client_ip, client_port = addr.split(':')
# TODO[Rhett]: server should not be blocked endlessly. # TODO[Rhett]: server should not be blocked endlessly.
while not rpc.connect_receiver(client_ip, client_port, client_id): while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
time.sleep(1) time.sleep(1)
if rpc.get_rank() == 0: # server_0 send all the IDs if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items(): for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id) register_res = rpc.ClientRegisterResponse(client_id)
rpc.send_response(client_id, register_res) rpc.send_response(client_id, register_res, group_id)
# main service loop # receive incomming client requests
while True: req, client_id, group_id = rpc.recv_request()
req, client_id = rpc.recv_request() if isinstance(req, rpc.ClientRegisterRequest):
if group_id not in recv_clients:
recv_clients[group_id] = []
recv_clients[group_id].append(req.ip_addr)
continue
res = req.process_request(server_state) res = req.process_request(server_state)
if res is not None: if res is not None:
if isinstance(res, list): if isinstance(res, list):
for response in res: for response in res:
target_id, res_data = response target_id, res_data = response
rpc.send_response(target_id, res_data) rpc.send_response(target_id, res_data, group_id)
elif isinstance(res, str) and res == 'exit': elif isinstance(res, str):
break # break the loop and exit server if res == SERVER_EXIT:
print("Server is exiting...")
return
elif res == SERVER_KEEP_ALIVE:
print("Server keeps alive while client group~{} is exiting...".format(group_id))
else:
raise DGLError("Unexpected response: {}".format(res))
else: else:
rpc.send_response(client_id, res) rpc.send_response(client_id, res, group_id)
...
@@ -38,12 +38,15 @@ class ServerState:
         Total number of edges
     partition_book : GraphPartitionBook
         Graph Partition book
+    keep_alive : bool
+        Whether to keep the server alive, which allows any number of
+        client groups to connect.
     """
-    def __init__(self, kv_store, local_g, partition_book):
+    def __init__(self, kv_store, local_g, partition_book, keep_alive=False):
         self._kv_store = kv_store
         self._graph = local_g
         self.partition_book = partition_book
+        self._keep_alive = keep_alive
         self._roles = {}
 
     @property
@@ -69,5 +72,9 @@ class ServerState:
     def graph(self, graph):
         self._graph = graph
 
+    @property
+    def keep_alive(self):
+        """Flag indicating whether the server keeps alive."""
+        return self._keep_alive
+
 _init_api("dgl.distributed.server_state")
...
@@ -223,13 +223,19 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
 
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
-  *rv = RPCContext::getInstance()->barrier_count;
+  const int32_t group_id = args[0];
+  auto&& cnt = RPCContext::getInstance()->barrier_count;
+  if (cnt.find(group_id) == cnt.end()) {
+    cnt.emplace(group_id, 0x0);
+  }
+  *rv = cnt[group_id];
 });
 
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
   const int32_t count = args[0];
-  RPCContext::getInstance()->barrier_count = count;
+  const int32_t group_id = args[1];
+  RPCContext::getInstance()->barrier_count[group_id] = count;
 });
 
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
@@ -296,6 +302,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
       args[4];  // directly assigning string value raises errors :(
   rst->data = data;
   rst->tensors = ListValueToVector<NDArray>(args[5]);
+  rst->group_id = args[6];
   *rv = rst;
 });
@@ -464,6 +471,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
       msg.data = pickle_data;
       NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
       msg.tensors.push_back(tensor);
+      msg.group_id = RPCContext::getInstance()->group_id;
       SendRPCMessage(msg, msg.server_id);
       msg_count++;
     }
@@ -499,6 +507,37 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
   *rv = res_tensor;
 });
 
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
+.set_body([](DGLArgs args, DGLRetValue* rv) {
+  *rv = RPCContext::getInstance()->group_id;
+});
+
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
+.set_body([](DGLArgs args, DGLRetValue* rv) {
+  const int32_t group_id = args[0];
+  RPCContext::getInstance()->group_id = group_id;
+});
+
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
+.set_body([](DGLArgs args, DGLRetValue* rv) {
+  const RPCMessageRef msg = args[0];
+  *rv = msg->group_id;
+});
+
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
+.set_body([](DGLArgs args, DGLRetValue* rv) {
+  const int32_t client_id = args[0];
+  const int32_t group_id = args[1];
+  *rv = RPCContext::getInstance()->RegisterClient(client_id, group_id);
+});
+
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
+.set_body([](DGLArgs args, DGLRetValue* rv) {
+  const int32_t client_id = args[0];
+  const int32_t group_id = args[1];
+  *rv = RPCContext::getInstance()->GetClient(client_id, group_id);
+});
+
 }  // namespace rpc
 }  // namespace dgl
...
@@ -16,6 +16,7 @@
 #include <vector>
 #include <string>
 #include <mutex>
+#include <unordered_map>
 
 #include "./rpc_msg.h"
 #include "./tensorpipe/tp_communicator.h"
@@ -68,7 +69,7 @@ struct RPCContext {
   /*!
    * \brief Current barrier count
   */
-  int32_t barrier_count = 0;
+  std::unordered_map<int32_t, int32_t> barrier_count;
 
   /*!
    * \brief Total number of server per machine.
@@ -101,6 +102,13 @@ struct RPCContext {
   */
  std::shared_ptr<ServerState> server_state;
 
+  /*!
+   * \brief Current group ID
+   */
+  int32_t group_id = -1;
+
+  int32_t curr_client_id = -1;
+  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;
+
   /*! \brief Get the RPC context singleton */
   static RPCContext* getInstance() {
     static RPCContext ctx;
@@ -116,12 +124,35 @@ struct RPCContext {
     t->msg_seq = 0;
     t->num_servers = 0;
     t->num_clients = 0;
-    t->barrier_count = 0;
+    t->barrier_count.clear();
     t->num_servers_per_machine = 0;
     t->sender.reset();
     t->receiver.reset();
     t->ctx.reset();
     t->server_state.reset();
+    t->group_id = -1;
+    t->curr_client_id = -1;
+    t->clients_.clear();
+  }
+
+  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
+    auto&& m = clients_[group_id];
+    if (m.find(client_id) != m.end()) {
+      return -1;
+    }
+    m[client_id] = ++curr_client_id;
+    return curr_client_id;
+  }
+
+  int32_t GetClient(int32_t client_id, int32_t group_id) const {
+    if (clients_.find(group_id) == clients_.end()) {
+      return -1;
+    }
+    const auto& m = clients_.at(group_id);
+    if (m.find(client_id) == m.end()) {
+      return -1;
+    }
+    return m.at(client_id);
   }
 };
...
@@ -38,6 +38,9 @@ struct RPCMessage : public runtime::Object {
   /*! \brief Extra payloads in the form of tensors.*/
   std::vector<runtime::NDArray> tensors;
 
+  /*! \brief Group ID. */
+  int32_t group_id{0};
+
   bool Load(dmlc::Stream* stream) {
     stream->Read(&service_id);
     stream->Read(&msg_seq);
@@ -45,6 +48,7 @@ struct RPCMessage : public runtime::Object {
     stream->Read(&server_id);
     stream->Read(&data);
     stream->Read(&tensors);
+    stream->Read(&group_id);
     return true;
   }
@@ -55,6 +59,7 @@ struct RPCMessage : public runtime::Object {
     stream->Write(server_id);
     stream->Write(data);
     stream->Write(tensors);
+    stream->Write(group_id);
   }
 
   static constexpr const char* _type_key = "rpc.RPCMessage";
...
...@@ -28,11 +28,11 @@ def create_random_graph(n): ...@@ -28,11 +28,11 @@ def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64) arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
return dgl.from_scipy(arr) return dgl.from_scipy(arr)
def run_server(graph_name, server_id, server_count, num_clients, shared_mem): def run_server(graph_name, server_id, server_count, num_clients, shared_mem, keep_alive=False):
g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients, g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients,
'/tmp/dist_graph/{}.json'.format(graph_name), '/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem, disable_shared_mem=not shared_mem,
graph_format=['csc', 'coo']) graph_format=['csc', 'coo'], keep_alive=keep_alive)
print('start server', server_id) print('start server', server_id)
g.start() g.start()
...@@ -114,16 +114,18 @@ def check_server_client_empty(shared_mem, num_servers, num_clients): ...@@ -114,16 +114,18 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
print('clients have terminated') print('clients have terminated')
def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges): def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges, group_id):
os.environ['DGL_NUM_SERVER'] = str(server_count) os.environ['DGL_NUM_SERVER'] = str(server_count)
os.environ['DGL_GROUP_ID'] = str(group_id)
dgl.distributed.initialize("kv_ip_config.txt") dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name), gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None) part_id, None)
g = DistGraph(graph_name, gpb=gpb) g = DistGraph(graph_name, gpb=gpb)
check_dist_graph(g, num_clients, num_nodes, num_edges) check_dist_graph(g, num_clients, num_nodes, num_edges)
def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges): def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges, group_id):
os.environ['DGL_NUM_SERVER'] = str(server_count) os.environ['DGL_NUM_SERVER'] = str(server_count)
os.environ['DGL_GROUP_ID'] = str(group_id)
dgl.distributed.initialize("kv_ip_config.txt") dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name), gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None) part_id, None)
...@@ -278,13 +280,13 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges): ...@@ -278,13 +280,13 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
print('end') print('end')
def check_dist_emb_server_client(shared_mem, num_servers, num_clients): def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist() prepare_dist()
g = create_random_graph(10000) g = create_random_graph(10000)
# Partition the graph # Partition the graph
num_parts = 1 num_parts = 1
graph_name = 'dist_graph_test_2' graph_name = f'check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1) g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1) g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp/dist_graph') partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
...@@ -293,18 +295,21 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem, keep_alive))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print('start client[{}] for group[{}]'.format(cli_id, group_id))
            p = ctx.Process(target=run_emb_client, args=(graph_name, 0, num_servers, num_clients,
                                                         g.number_of_nodes(),
                                                         g.number_of_edges(),
                                                         group_id))
            p.start()
            cli_ps.append(p)
...@@ -312,18 +317,23 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
        p.join()
        assert p.exitcode == 0
    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # forcibly shut down the keep-alive servers
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()
    print('clients have terminated')
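
# [Editor's sketch, not part of this commit] The keep-alive lifecycle the two
# checks above exercise, distilled into one helper: servers started once with
# keep_alive=True outlive every client group and only exit on an explicit
# shutdown_servers() call. It assumes the run_server/run_client helpers
# defined in this file and an existing kv_ip_config.txt; counts are illustrative.
def sketch_keep_alive_lifecycle(graph_name, num_nodes, num_edges,
                                num_servers=1, num_clients=1, num_groups=2):
    ctx = mp.get_context('spawn')
    # Servers are started once and survive every client group.
    serv_ps = []
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, True, True))
        serv_ps.append(p)
        p.start()
    # Each (client, group) pair is its own process; groups share the servers.
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers,
                                                     num_clients, num_nodes,
                                                     num_edges, group_id))
            p.start()
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
    # Keep-alive servers never exit on their own; tell them to shut down.
    dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()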
def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
    prepare_dist()
    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = f'check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
...@@ -332,23 +342,30 @@ def check_server_client(shared_mem, num_servers, num_clients):
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem, keep_alive))
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print('start client[{}] for group[{}]'.format(cli_id, group_id))
            p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients,
                                                     g.number_of_nodes(),
                                                     g.number_of_edges(), group_id))
            p.start()
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # forcibly shut down the keep-alive servers
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()
...@@ -567,6 +584,9 @@ def test_server_client():
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    check_server_client(True, 2, 2)
    check_server_client(True, 1, 1, 5)
    check_server_client(False, 1, 1, 5)
    check_server_client(True, 2, 2, 5)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
...@@ -577,6 +597,9 @@ def test_dist_emb_server_client():
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    check_dist_emb_server_client(True, 2, 2)
    check_dist_emb_server_client(True, 1, 1, 5)
    check_dist_emb_server_client(False, 1, 1, 5)
    check_dist_emb_server_client(True, 2, 2, 5)

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
...
...@@ -18,10 +18,11 @@ import random
from dgl.distributed import DistGraphServer, DistGraph

def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format=['csc', 'coo'],
                 keep_alive=False):
    g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
                        tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
                        graph_format=graph_format, keep_alive=keep_alive)
    g.start()

...@@ -39,6 +40,32 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
    dgl.distributed.exit_client()
    return sampled_graph
def start_sample_client_shuffle(rank, tmpdir, disable_shared_mem, g, num_servers, group_id=0):
    os.environ['DGL_GROUP_ID'] = str(group_id)
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_servers):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
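
# [Editor's sketch, not part of this commit] Group selection happens through
# the DGL_GROUP_ID environment variable, set before dgl.distributed.initialize()
# as in start_sample_client_shuffle above, so the process (and its sampler
# workers) connects under that group. A minimal grouped client; the graph name
# 'test_sampling' matches the partitions created in this file.
def sketch_grouped_client(group_id):
    os.environ['DGL_GROUP_ID'] = str(group_id)  # must precede initialize()
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling")     # partition book from shared memory
    print(dist_graph.number_of_nodes())
    dgl.distributed.exit_client()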
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    gpb = None
    if disable_shared_mem:
...@@ -247,7 +274,7 @@ def test_rpc_sampling():
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname), 2)

def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = CitationGraphDataset("cora")[0]
...@@ -260,33 +287,31 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
    pserver_list = []
    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, 'test_sampling', ['csc', 'coo'], keep_alive))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    pclient_list = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            p = ctx.Process(target=start_sample_client_shuffle, args=(
                client_id, tmpdir, num_server > 1, g, num_server, group_id))
            p.start()
            pclient_list.append(p)
    for p in pclient_list:
        p.join()
    if keep_alive:
        for p in pserver_list:
            assert p.is_alive()
        # forcibly shut down the keep-alive servers
        dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)
    for p in pserver_list:
        p.join()
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    gpb = None
    if disable_shared_mem:
...@@ -538,6 +563,7 @@ def test_rpc_sampling_shuffle(num_server):
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
        check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=5)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
...
...@@ -38,18 +38,19 @@ class NeighborSampler(object):
        return blocks

def start_server(rank, tmpdir, disable_shared_mem, num_clients, keep_alive=False):
    import dgl
    print('server: #clients=' + str(num_clients))
    g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem,
                        graph_format=['csc', 'coo'], keep_alive=keep_alive)
    g.start()

def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id=0):
    import dgl
    import torch as th
    os.environ['DGL_GROUP_ID'] = str(group_id)
    dgl.distributed.initialize("mp_ip_config.txt")
    gpb = None
    disable_shared_mem = num_server > 0
...@@ -120,7 +121,6 @@ def test_standalone(tmpdir):
        start_dist_dataloader(0, tmpdir, 1, True, orig_nid, orig_eid)
    except Exception as e:
        print(e)
def start_dist_neg_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, groundtruth_g):
    import dgl
...@@ -213,7 +213,8 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("reshuffle", [True, False])
@pytest.mark.parametrize("num_groups", [1, 5])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, num_groups):
    reset_envs()
    generate_ip_config("mp_ip_config.txt", num_server, num_server)
...@@ -228,22 +229,34 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
    pserver_list = []
    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers+1, keep_alive))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    ptrainer_list = []
    num_trainers = 1
    for trainer_id in range(num_trainers):
        for group_id in range(num_groups):
            p = ctx.Process(target=start_dist_dataloader, args=(
                trainer_id, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id))
            p.start()
            ptrainer_list.append(p)
    for p in ptrainer_list:
        p.join()
    if keep_alive:
        for p in pserver_list:
            assert p.is_alive()
        # forcibly shut down the keep-alive servers
        dgl.distributed.shutdown_servers("mp_ip_config.txt", 1)
    for p in pserver_list:
        p.join()
def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid, groundtruth_g):
    import dgl
...@@ -438,7 +451,8 @@ if __name__ == "__main__":
        test_dataloader(Path(tmpdirname), 3, 4, 'node')
        test_dataloader(Path(tmpdirname), 3, 4, 'edge')
        test_neg_dataloader(Path(tmpdirname), 3, 4)
        for num_groups in [1, 5]:
            test_dist_dataloader(Path(tmpdirname), 3, 0, True, True, num_groups)
            test_dist_dataloader(Path(tmpdirname), 3, 4, True, True, num_groups)
            test_dist_dataloader(Path(tmpdirname), 3, 0, True, False, num_groups)
            test_dist_dataloader(Path(tmpdirname), 3, 4, True, False, num_groups)
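
# [Editor's sketch, not part of this commit] The environment a grouped trainer
# process relies on before dgl.distributed.initialize(), gathered from the test
# above; with the 'spawn' context, variables set in the parent are inherited by
# the child. Values here are illustrative.
def sketch_trainer_env(group_id, num_workers):
    os.environ['DGL_DIST_MODE'] = 'distributed'       # talk to real servers
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)  # sampler subprocesses
    os.environ['DGL_GROUP_ID'] = str(group_id)        # client group to join
    dgl.distributed.initialize("mp_ip_config.txt")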
...@@ -83,21 +83,24 @@ class HelloRequest(dgl.distributed.Request):
        res = HelloResponse(self.hello_str, self.integer, new_tensor)
        return res

def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1):
    print("Sleep 1 second to test client re-connect.")
    time.sleep(1)
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive)
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(server_id=server_id,
                                 ip_config=ip_config,
                                 num_servers=num_servers,
                                 num_clients=num_clients,
                                 server_state=server_state)

def start_client(ip_config, group_id=0, num_servers=1):
    dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id)
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
...@@ -238,6 +241,40 @@ def test_multi_thread_rpc():
    pserver.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client_groups():
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # Should test with more servers per machine, but ports may already be in use.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # pressure test
    num_clients = 15
    num_groups = 15
    ctx = mp.get_context('spawn')
    pserver_list = []
    for i in range(num_servers*num_machines):
        pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, i, True, num_servers))
        pserver.start()
        pserver_list.append(pserver)
    pclient_list = []
    for i in range(num_clients):
        for group_id in range(num_groups):
            pclient = ctx.Process(target=start_client, args=(ip_config, group_id, num_servers))
            pclient.start()
            pclient_list.append(pclient)
    for p in pclient_list:
        p.join()
    for p in pserver_list:
        assert p.is_alive()
    # forcibly shut down the keep-alive servers
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for p in pserver_list:
        p.join()
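
# [Editor's sketch, not part of this commit] The smallest version of the test
# above: one machine, one keep-alive server, two client groups run one after
# the other, then an explicit shutdown. Assumes the start_server/start_client
# helpers in this file; the ip-config file name is made up.
def sketch_two_sequential_groups():
    ip_config = "rpc_ip_config_sketch.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, True, 1))
    pserver.start()
    for group_id in range(2):
        pclient = ctx.Process(target=start_client, args=(ip_config, group_id, 1))
        pclient.start()
        pclient.join()
    # The server survives both groups until shutdown_servers() is called.
    assert pserver.is_alive()
    dgl.distributed.shutdown_servers(ip_config, 1)
    pserver.join()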
if __name__ == '__main__':
    test_serialize()
    test_rpc_msg()
...