Unverified Commit 02e4cd8b authored by Rhett Ying, committed by GitHub

[Feature] long live server for multiple client groups (#3645)

* [Feature] long live server for multiple client groups

* generate globally unique name for DistTensor within DGL automatically
parent 2b98e764
......@@ -23,7 +23,7 @@ from . import optim
from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server
from .rpc_client import connect_to_server, shutdown_servers
from .dist_context import initialize, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
......
......@@ -2,3 +2,6 @@
# Maximum size of message queue in bytes
MAX_QUEUE_SIZE = 20*1024*1024*1024
SERVER_EXIT = "server_exit"
SERVER_KEEP_ALIVE = "server_keep_alive"
......@@ -8,12 +8,13 @@ import time
import os
import sys
import queue
import gc
from enum import Enum
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore
from .rpc_client import connect_to_server, shutdown_servers
from .rpc_client import connect_to_server
from .role import init_role
from .. import utils
......@@ -33,13 +34,13 @@ def get_sampler_pool():
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads):
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads, group_id):
''' This init function is called in the worker processes.
'''
try:
utils.set_num_threads(num_threads)
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id)
init_role(role)
init_kvstore(ip_config, num_servers, role)
except Exception as e:
......@@ -227,12 +228,14 @@ def initialize(ip_config, num_servers=1, num_workers=0,
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
rpc.reset()
keep_alive = os.environ.get('DGL_KEEP_ALIVE') is not None
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'),
graph_format=formats)
graph_format=formats,
keep_alive=keep_alive)
serv.start()
sys.exit()
else:
......@@ -244,7 +247,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
num_servers = int(os.environ.get('DGL_NUM_SERVER'))
else:
num_servers = 1
group_id = int(os.environ.get('DGL_GROUP_ID', 0))
rpc.reset()
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
......@@ -252,14 +255,15 @@ def initialize(ip_config, num_servers=1, num_workers=0,
'DGL_DIST_MODE', 'standalone') == 'standalone'
if num_workers > 0 and not is_standalone:
SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size,
net_type, 'sampler', num_worker_threads))
net_type, 'sampler', num_worker_threads,
group_id))
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
assert num_servers is not None and num_servers > 0, \
'The number of servers per machine must be specified with a positive number.'
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id=group_id)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')
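For orientation, here is a minimal sketch of how a grouped deployment is driven by the environment variables read above (the variable names come from this diff; the config path is hypothetical):

```python
import os
import dgl

# Each client launch group sets its own DGL_GROUP_ID (defaults to 0);
# initialize() forwards it to connect_to_server() and to sampler workers.
os.environ['DGL_GROUP_ID'] = '1'
# On the server side, setting DGL_KEEP_ALIVE (to any value) makes the
# server survive client-group exits; the code above checks mere presence.
dgl.distributed.initialize('ip_config.txt')  # hypothetical config path
```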
......@@ -299,6 +303,14 @@ def is_initialized():
return INITIALIZED
def _shutdown_servers():
set_initialized(False)
# send ShutDownRequest to servers
if rpc.get_rank() == 0: # Only client_0 issues this command
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)
def exit_client():
"""Trainer exits
......@@ -311,9 +323,11 @@ def exit_client():
"""
# Only client with rank_0 will send shutdown request to servers.
finalize_worker() # finalizing workers must happen before the barrier; it is non-blocking
# collect data such as DistTensor before exit
gc.collect()
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
rpc.client_barrier()
shutdown_servers()
_shutdown_servers()
finalize_client()
join_finalize_worker()
close_kvstore()
......
......@@ -191,7 +191,7 @@ class NodeDataView(MutableMapping):
dtype, shape, _ = g._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
part_policy=policy)
part_policy=policy, attach=False)
def _get_names(self):
return list(self._data.keys())
......@@ -245,7 +245,7 @@ class EdgeDataView(MutableMapping):
dtype, shape, _ = g._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
part_policy=policy)
part_policy=policy, attach=False)
def _get_names(self):
return list(self._data.keys())
......@@ -308,16 +308,19 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
keep_alive : bool
Whether to keep the server alive when clients exit
'''
def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False,
graph_format=('csc', 'coo')):
graph_format=('csc', 'coo'), keep_alive=False):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
num_clients=num_clients)
self.ip_config = ip_config
self.num_servers = num_servers
self.keep_alive = keep_alive
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll be initialized afterwards.
......@@ -351,6 +354,7 @@ class DistGraphServer(KVServer):
data_name = HeteroDataName(True, ntype, feat_name)
self.init_data(name=str(data_name), policy_str=data_name.policy_str,
data_tensor=node_feats[name])
self.orig_data.add(str(data_name))
for name in edge_feats:
# The feature name has the following format: edge_type + "/" + feature_name to avoid
# feature name collision for different edge types.
......@@ -358,13 +362,16 @@ class DistGraphServer(KVServer):
data_name = HeteroDataName(False, etype, feat_name)
self.init_data(name=str(data_name), policy_str=data_name.policy_str,
data_tensor=edge_feats[name])
self.orig_data.add(str(data_name))
def start(self):
""" Start graph store server.
"""
# start server
server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
print('start graph service on server {} for part {}'.format(self.server_id, self.part_id))
server_state = ServerState(kv_store=self, local_g=self.client_g,
partition_book=self.gpb, keep_alive=self.keep_alive)
print('start graph service on server {} for part {}'.format(
self.server_id, self.part_id))
start_server(server_id=self.server_id,
ip_config=self.ip_config,
num_servers=self.num_servers,
......
......@@ -7,6 +7,7 @@ from .kvstore import get_kvstore
from .role import get_role
from .. import utils
from .. import backend as F
from .rpc import get_group_id
def _default_init_data(shape, dtype):
return F.zeros(shape, dtype, F.cpu())
......@@ -80,6 +81,8 @@ class DistTensor:
Whether the created tensor lives after the ``DistTensor`` object is destroyed.
is_gdata : bool
Whether the created tensor is ndata/edata or not.
attach : bool
Whether to attach the group ID to the tensor name to make it globally unique.
Examples
--------
......@@ -102,12 +105,13 @@ class DistTensor:
do the same.
'''
def __init__(self, shape, dtype, name=None, init_func=None, part_policy=None,
persistent=False, is_gdata=True):
persistent=False, is_gdata=True, attach=True):
self.kvstore = get_kvstore()
assert self.kvstore is not None, \
'Distributed module is not initialized. Please call dgl.distributed.initialize.'
self._shape = shape
self._dtype = dtype
self._attach = attach
part_policies = self.kvstore.all_possible_part_policy
# If a user doesn't provide a partition policy, we should find one based on
......@@ -128,7 +132,6 @@ class DistTensor:
+ 'its first dimension does not match the number of nodes or edges ' \
+ 'of a distributed graph or there does not exist a distributed graph.'
self._tensor_name = name
self._part_policy = part_policy
assert part_policy.get_size() == shape[0], \
'The partition policy does not match the input shape.'
......@@ -146,6 +149,8 @@ class DistTensor:
name = 'anonymous-' + get_role() + '-' + str(DIST_TENSOR_ID)
DIST_TENSOR_ID += 1
assert isinstance(name, str), 'name {} is type {}'.format(name, type(name))
name = self._attach_group_id(name)
self._tensor_name = name
data_name = part_policy.get_data_name(name)
self._name = str(data_name)
self._persistent = persistent
......@@ -220,7 +225,7 @@ class DistTensor:
str
The name of the tensor.
'''
return self._name
return self._detach_group_id(self._name)
@property
def tensor_name(self):
......@@ -231,7 +236,7 @@ class DistTensor:
str
The name of the tensor.
'''
return self._tensor_name
return self._detach_group_id(self._tensor_name)
def count_nonzero(self):
'''Count and return the number of nonzero value
......@@ -241,4 +246,29 @@ class DistTensor:
int
the number of nonzero value
'''
return self.kvstore.count_nonzero(name=self.name)
return self.kvstore.count_nonzero(name=self._name)
def _attach_group_id(self, name):
"""Attach group ID if needed
Returns
-------
str
new name with group ID attached
"""
if not self._attach:
return name
return "{}_{}".format(name, get_group_id())
def _detach_group_id(self, name):
"""Detach group ID if needed
Returns
-------
str
original name without group ID
"""
if not self._attach:
return name
suffix = "_{}".format(get_group_id())
return name[:-len(suffix)]
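As an illustration of the naming scheme, here is a standalone sketch that mirrors the two helpers above (the functions below are hypothetical stand-ins, assuming a group ID of 2):

```python
def attach_group_id(name, group_id, attach=True):
    # Mirrors DistTensor._attach_group_id: append "_<group_id>".
    return "{}_{}".format(name, group_id) if attach else name

def detach_group_id(name, group_id, attach=True):
    # Mirrors DistTensor._detach_group_id: strip the "_<group_id>" suffix.
    # Note it assumes the suffix is present when attach is True.
    return name[:-len("_{}".format(group_id))] if attach else name

assert attach_group_id("feat", 2) == "feat_2"
assert detach_group_id("feat_2", 2) == "feat"
```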
......@@ -206,20 +206,23 @@ class BarrierRequest(rpc.Request):
"""
def __init__(self, role):
self.role = role
self.group_id = rpc.get_group_id()
def __getstate__(self):
return self.role
return self.role, self.group_id
def __setstate__(self, state):
self.role = state
self.role, self.group_id = state
def process_request(self, server_state):
kv_store = server_state.kv_store
role = server_state.roles
count = kv_store.barrier_count[self.role]
kv_store.barrier_count[self.role] = count + 1
if kv_store.barrier_count[self.role] == len(role[self.role]):
kv_store.barrier_count[self.role] = 0
roles = server_state.roles
role = roles[self.group_id]
barrier_count = kv_store.barrier_count[self.group_id]
count = barrier_count[self.role]
barrier_count[self.role] = count + 1
if barrier_count[self.role] == len(role[self.role]):
barrier_count[self.role] = 0
res_list = []
for client_id, _ in role[self.role]:
res_list.append((client_id, BarrierResponse(BARRIER_MSG)))
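A minimal Python model of the per-group barrier bookkeeping above, for illustration (names are hypothetical):

```python
from collections import defaultdict

# barrier_count[group_id][role] counts how many clients of `role` in
# `group_id` have reached the barrier; it resets when all members arrive.
barrier_count = defaultdict(lambda: defaultdict(int))

def reach_barrier(group_id, role, role_size):
    barrier_count[group_id][role] += 1
    if barrier_count[group_id][role] == role_size:
        barrier_count[group_id][role] = 0
        return True   # release every client of this (group, role)
    return False      # keep waiting
```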
......@@ -362,6 +365,9 @@ class GetSharedDataRequest(rpc.Request):
meta = {}
kv_store = server_state.kv_store
for name, data in kv_store.data_store.items():
if server_state.keep_alive:
if name not in kv_store.orig_data:
continue
meta[name] = (F.shape(data),
F.reverse_data_type_dict[F.dtype(data)],
kv_store.part_policy[name].policy_str)
......@@ -671,6 +677,8 @@ class KVServer(object):
CountLocalNonzeroResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store original tensor data names when instantiating DistGraphServer
self._orig_data = set()
# Store the partition information with specified data name
self._policy_set = set()
self._part_policy = {}
......@@ -715,6 +723,11 @@ class KVServer(object):
"""Get data store"""
return self._data_store
@property
def orig_data(self):
"""Get original data"""
return self._orig_data
@property
def part_policy(self):
"""Get part policy"""
......
......@@ -39,20 +39,22 @@ class RegisterRoleRequest(rpc.Request):
self.client_id = client_id
self.machine_id = machine_id
self.role = role
self.group_id = rpc.get_group_id()
def __getstate__(self):
return self.client_id, self.machine_id, self.role
return self.client_id, self.machine_id, self.role, self.group_id
def __setstate__(self, state):
self.client_id, self.machine_id, self.role = state
self.client_id, self.machine_id, self.role, self.group_id = state
def process_request(self, server_state):
kv_store = server_state.kv_store
role = server_state.roles
role = server_state.roles.setdefault(self.group_id, {})
if self.role not in role:
role[self.role] = set()
if kv_store is not None:
kv_store.barrier_count[self.role] = 0
barrier_count = kv_store.barrier_count.setdefault(self.group_id, {})
barrier_count[self.role] = 0
role[self.role].add((self.client_id, self.machine_id))
total_count = 0
for key in role:
......@@ -84,15 +86,16 @@ class GetRoleRequest(rpc.Request):
"""Send a request to get the roles of all client processes."""
def __init__(self):
self.msg = GET_ROLE_MSG
self.group_id = rpc.get_group_id()
def __getstate__(self):
return self.msg
return self.msg, self.group_id
def __setstate__(self, state):
self.msg = state
self.msg, self.group_id = state
def process_request(self, server_state):
return GetRoleResponse(server_state.roles)
return GetRoleResponse(server_state.roles[self.group_id])
# The key is role; the value is a dict mapping RPC rank to a rank within the role.
PER_ROLE_RANK = {}
......
......@@ -6,6 +6,8 @@ import pickle
import random
import numpy as np
from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
from ..base import DGLError
......@@ -156,7 +158,7 @@ def receiver_wait(ip_addr, port, num_senders, blocking=True):
"""
_CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)
def connect_receiver(ip_addr, port, recv_id):
def connect_receiver(ip_addr, port, recv_id, group_id=-1):
"""Connect to target receiver
Parameters
......@@ -168,7 +170,10 @@ def connect_receiver(ip_addr, port, recv_id):
recv_id : int
receiver's ID
group_id : int, optional
Group ID of the calling client. The default -1 means the receiver ID
is used directly; otherwise the (recv_id, group_id) pair is registered
and remapped to a globally unique target ID.
"""
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(recv_id))
target_id = recv_id if group_id == -1 else register_client(recv_id, group_id)
if target_id < 0:
raise DGLError("Invalid target id: {}".format(target_id))
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
def set_rank(rank):
"""Set the rank of this process.
......@@ -497,8 +502,10 @@ class RPCMessage(ObjectBase):
Payload buffer carried by this request.
tensors : list[tensor]
Extra payloads in the form of tensors.
group_id : int
The group ID
"""
def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors):
def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors, group_id=0):
self.__init_handle_by_constructor__(
_CAPI_DGLRPCCreateRPCMessage,
int(service_id),
......@@ -506,7 +513,8 @@ class RPCMessage(ObjectBase):
int(client_id),
int(server_id),
data,
[F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors])
[F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors],
int(group_id))
@property
def service_id(self):
......@@ -539,6 +547,11 @@ class RPCMessage(ObjectBase):
rst = _CAPI_DGLRPCMessageGetTensors(self)
return [F.zerocopy_from_dgl_ndarray(tsor) for tsor in rst]
@property
def group_id(self):
"""Get group ID."""
return _CAPI_DGLRPCMessageGetGroupId(self)
def send_request(target, request):
"""Send one request to the target server.
......@@ -566,7 +579,8 @@ def send_request(target, request):
client_id = get_rank()
server_id = target
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
msg = RPCMessage(service_id, msg_seq, client_id, server_id,
data, tensors, group_id=get_group_id())
send_rpc_message(msg, server_id)
def send_request_to_machine(target, request):
......@@ -595,10 +609,10 @@ def send_request_to_machine(target, request):
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
send_rpc_message(msg, server_id)
def send_response(target, response):
def send_response(target, response, group_id):
"""Send one response to the target client.
Serialize the given response object to an :class:`RPCMessage` and send it
......@@ -615,6 +629,8 @@ def send_response(target, response):
ID of target client.
response : Response
The response to send.
group_id : int
Group ID of target client.
Raises
------
......@@ -625,8 +641,8 @@ def send_response(target, response):
client_id = target
server_id = get_rank()
data, tensors = serialize_to_payload(response)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
send_rpc_message(msg, client_id)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, group_id)
send_rpc_message(msg, get_client(client_id, group_id))
def recv_request(timeout=0):
"""Receive one request.
......@@ -647,6 +663,8 @@ def recv_request(timeout=0):
One request received from the target, or None if it times out.
client_id : int
Client ID received from the target.
group_id : int
Group ID received from the target.
Raises
------
......@@ -665,7 +683,7 @@ def recv_request(timeout=0):
if msg.server_id != get_rank():
raise DGLError('Got request sent to server {}, '
'different from my rank {}!'.format(msg.server_id, get_rank()))
return req, msg.client_id
return req, msg.client_id, msg.group_id
def recv_response(timeout=0):
"""Receive one response.
......@@ -699,8 +717,11 @@ def recv_response(timeout=0):
'but no response class is registered.'.format(msg.service_id))
res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != get_rank() and get_rank() != -1:
raise DGLError('Got reponse of request sent by client {}, '
raise DGLError('Got response of request sent by client {}, '
'different from my rank {}!'.format(msg.client_id, get_rank()))
if msg.group_id != get_group_id():
raise DGLError("Got response of request sent by group {}, "
"different from my group {}!".format(msg.group_id, get_group_id()))
return res
def remote_call(target_and_requests, timeout=0):
......@@ -742,7 +763,7 @@ def remote_call(target_and_requests, timeout=0):
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
send_rpc_message(msg, server_id)
# check if has response
res_cls = get_service_property(service_id)[1]
......@@ -792,7 +813,7 @@ def send_requests_to_machine(target_and_requests):
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
send_rpc_message(msg, server_id)
# check if has response
res_cls = get_service_property(service_id)[1]
......@@ -1050,19 +1071,22 @@ class ShutDownRequest(Request):
client_id : int
client's ID
"""
def __init__(self, client_id):
def __init__(self, client_id, force_shutdown_server=False):
self.client_id = client_id
self.force_shutdown_server = force_shutdown_server
def __getstate__(self):
return self.client_id
return self.client_id, self.force_shutdown_server
def __setstate__(self, state):
self.client_id = state
self.client_id, self.force_shutdown_server = state
def process_request(self, server_state):
assert self.client_id == 0
if server_state.keep_alive and not self.force_shutdown_server:
return SERVER_KEEP_ALIVE
finalize_server()
return 'exit'
return SERVER_EXIT
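A sketch of the two shutdown paths from the client side, assuming a keep-alive server (only request construction is shown; sending requires an initialized RPC client):

```python
from dgl.distributed import rpc

# Regular exit: a keep-alive server answers SERVER_KEEP_ALIVE and stays up.
soft_req = rpc.ShutDownRequest(client_id=0)
# Forced exit: the server calls finalize_server() and returns SERVER_EXIT.
hard_req = rpc.ShutDownRequest(client_id=0, force_shutdown_server=True)
```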
GET_NUM_CLIENT = 22453
......@@ -1133,21 +1157,69 @@ class ClientBarrierRequest(Request):
"""
def __init__(self, msg='barrier'):
self.msg = msg
self.group_id = get_group_id()
def __getstate__(self):
return self.msg
return self.msg, self.group_id
def __setstate__(self, state):
self.msg = state
self.msg, self.group_id = state
def process_request(self, server_state):
_CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount()+1)
if _CAPI_DGLRPCGetBarrierCount() == get_num_client():
_CAPI_DGLRPCSetBarrierCount(0)
_CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount(self.group_id)+1, self.group_id)
if _CAPI_DGLRPCGetBarrierCount(self.group_id) == get_num_client():
_CAPI_DGLRPCSetBarrierCount(0, self.group_id)
res_list = []
for target_id in range(get_num_client()):
res_list.append((target_id, ClientBarrierResponse()))
return res_list
return None
def set_group_id(group_id):
"""Set current group ID
Parameters
----------
group_id : int
Current group ID
"""
_CAPI_DGLRPCSetGroupID(int(group_id))
def get_group_id():
"""Get current group ID
Returns
-------
int
group ID
"""
return _CAPI_DGLRPCGetGroupID()
def register_client(client_id, group_id):
"""Register client
Returns
-------
int
unique client ID
"""
return _CAPI_DGLRPCRegisterClient(int(client_id), int(group_id))
def get_client(client_id, group_id):
"""Get global client ID
Parameters
----------
client_id : int
client ID
group_id : int
group ID
Returns
-------
int
global client ID
"""
return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))
_init_api("dgl.distributed.rpc")
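The mapping behind `register_client`/`get_client` lives in the C++ `RPCContext` (see the `rpc.cc`/`rpc.h` hunks further down); here is a minimal Python model of its behavior, for illustration only:

```python
class ClientTable:
    """Illustrative model of RPCContext::RegisterClient/GetClient."""
    def __init__(self):
        self._next_global_id = -1
        self._clients = {}  # group_id -> {client_id: global_id}

    def register_client(self, client_id, group_id):
        group = self._clients.setdefault(group_id, {})
        if client_id in group:
            return -1  # duplicate registration fails
        self._next_global_id += 1
        group[client_id] = self._next_global_id
        return self._next_global_id

    def get_client(self, client_id, group_id):
        return self._clients.get(group_id, {}).get(client_id, -1)

table = ClientTable()
assert table.register_client(0, group_id=0) == 0  # first group, client 0
assert table.register_client(0, group_id=1) == 1  # same local ID, new group
assert table.get_client(0, group_id=1) == 1
```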
......@@ -103,7 +103,8 @@ def get_local_usable_addr(probe_addr):
return ip_addr + ':' + str(port)
def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
net_type='socket', group_id=0):
"""Connect this client to server.
Parameters
......@@ -118,6 +119,10 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket'.
group_id : int
Indicates which group this client belongs to. Clients that are
booted together in one launch are gathered as a group and should
share the same unique group_id.
Raises
------
......@@ -156,6 +161,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
rpc.set_num_machines(num_machines)
machine_id = get_local_machine_id(server_namebook)
rpc.set_machine_id(machine_id)
rpc.set_group_id(group_id)
rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes
......@@ -169,6 +175,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
client_ip, client_port = ip_addr.split(':')
# wait server connect back
rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
print("Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port))
# Register client on server
register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers):
......@@ -176,8 +183,8 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
# recv client ID from server
res = rpc.recv_response()
rpc.set_rank(res.client_id)
print("Machine (%d) client (%d) connect to server successfuly!" \
% (machine_id, rpc.get_rank()))
print("Machine (%d) group (%d) client (%d) connect to server successfuly!" \
% (machine_id, group_id, rpc.get_rank()))
# get total number of client
get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())
rpc.send_request(0, get_client_num_req)
......@@ -187,16 +194,44 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
atexit.register(exit_client)
set_initialized(True)
def shutdown_servers():
def shutdown_servers(ip_config, num_servers):
"""Issue commands to remote servers to shut them down.
This function is required to be called manually only when we
have booted servers which keep alive even clients exit. In
order to shut down server elegantly, we utilize existing
client logic/code to boot a special client which does nothing
but send shut down request to servers. Once such request is
received, servers will exit from endless wait loop, release
occupied resources and end its process. Please call this function
with same arguments used in `dgl.distributed.connect_to_server`.
Parameters
----------
ip_config : str
Path of server IP configuration file.
num_servers : int
server count on each machine.
Raises
------
ConnectionError : If anything goes wrong with the connection.
"""
from .dist_context import set_initialized
set_initialized(False)
if rpc.get_rank() == 0: # Only client_0 issue this command
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook)
rpc.create_sender(MAX_QUEUE_SIZE, 'socket')
# Get connected with all server nodes
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
while not rpc.connect_receiver(server_ip, server_port, server_id):
time.sleep(1)
# send ShutDownRequest to all servers
req = rpc.ShutDownRequest(0, True)
for server_id in range(num_servers):
rpc.send_request(server_id, req)
rpc.finalize_sender()
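A usage sketch of the new manual shutdown, assuming servers were started with `keep_alive=True` and the hypothetical `ip_config.txt` used at connect time (this mirrors the calls in the tests below):

```python
import dgl

# All client groups have exited, but keep-alive servers are still running;
# shut them down with the same arguments used for connect_to_server().
dgl.distributed.shutdown_servers("ip_config.txt", num_servers=1)
```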
"""Functions used by server."""
import time
from ..base import DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
......@@ -37,6 +38,9 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type
if server_state.keep_alive:
print("As configured, this server will stay alive for multiple"
" client groups until a force-shutdown request is received.")
# Register signal handler.
rpc.register_sig_handler()
# Register some basic services
......@@ -63,39 +67,52 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# wait all the senders connect to server.
# Once all the senders connect to server, server will not
# accept new sender's connection
print("Wait connections non-blockingly...")
print(
"Server is waiting for connections non-blockingly on [{}:{}]...".format(ip_addr, port))
rpc.receiver_wait(ip_addr, port, num_clients, blocking=False)
rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients
addr_list = []
client_namebook = {}
for _ in range(num_clients):
# blocked until request is received
req, _ = rpc.recv_request()
assert isinstance(req, rpc.ClientRegisterRequest)
addr_list.append(req.ip_addr)
addr_list.sort()
for client_id, addr in enumerate(addr_list):
client_namebook[client_id] = addr
recv_clients = {}
while True:
# check whether any client group is ready for connection
for group_id in list(recv_clients.keys()):
ips = recv_clients[group_id]
if len(ips) < rpc.get_num_client():
continue
else:
del recv_clients[group_id]
# a new client group is ready
ips.sort()
client_namebook = {client_id:addr for client_id, addr in enumerate(ips)}
for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':')
# TODO[Rhett]: server should not be blocked endlessly.
while not rpc.connect_receiver(client_ip, client_port, client_id):
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
time.sleep(1)
if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
rpc.send_response(client_id, register_res)
# main service loop
while True:
req, client_id = rpc.recv_request()
rpc.send_response(client_id, register_res, group_id)
# receive incoming client requests
req, client_id, group_id = rpc.recv_request()
if isinstance(req, rpc.ClientRegisterRequest):
if group_id not in recv_clients:
recv_clients[group_id] = []
recv_clients[group_id].append(req.ip_addr)
continue
res = req.process_request(server_state)
if res is not None:
if isinstance(res, list):
for response in res:
target_id, res_data = response
rpc.send_response(target_id, res_data)
elif isinstance(res, str) and res == 'exit':
break # break the loop and exit server
rpc.send_response(target_id, res_data, group_id)
elif isinstance(res, str):
if res == SERVER_EXIT:
print("Server is exiting...")
return
elif res == SERVER_KEEP_ALIVE:
print("Server keeps alive while client group~{} is exiting...".format(group_id))
else:
raise DGLError("Unexpected response: {}".format(res))
else:
rpc.send_response(client_id, res)
rpc.send_response(client_id, res, group_id)
......@@ -38,12 +38,15 @@ class ServerState:
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
keep_alive : bool
Whether to keep the server alive after clients exit, so that any number of client groups can connect
"""
def __init__(self, kv_store, local_g, partition_book):
def __init__(self, kv_store, local_g, partition_book, keep_alive=False):
self._kv_store = kv_store
self._graph = local_g
self.partition_book = partition_book
self._keep_alive = keep_alive
self._roles = {}
@property
......@@ -69,5 +72,9 @@ class ServerState:
def graph(self, graph):
self._graph = graph
@property
def keep_alive(self):
"""Flag of whether keep alive"""
return self._keep_alive
_init_api("dgl.distributed.server_state")
......@@ -223,13 +223,19 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->barrier_count;
const int32_t group_id = args[0];
auto&& cnt = RPCContext::getInstance()->barrier_count;
if (cnt.find(group_id) == cnt.end()) {
cnt.emplace(group_id, 0x0);
}
*rv = cnt[group_id];
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t count = args[0];
RPCContext::getInstance()->barrier_count = count;
const int32_t group_id = args[1];
RPCContext::getInstance()->barrier_count[group_id] = count;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
......@@ -296,6 +302,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
args[4]; // directly assigning string value raises errors :(
rst->data = data;
rst->tensors = ListValueToVector<NDArray>(args[5]);
rst->group_id = args[6];
*rv = rst;
});
......@@ -464,6 +471,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
msg.data = pickle_data;
NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
msg.tensors.push_back(tensor);
msg.group_id = RPCContext::getInstance()->group_id;
SendRPCMessage(msg, msg.server_id);
msg_count++;
}
......@@ -499,6 +507,37 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
*rv = res_tensor;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->group_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t group_id = args[0];
RPCContext::getInstance()->group_id = group_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->group_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t client_id = args[0];
const int32_t group_id = args[1];
*rv = RPCContext::getInstance()->RegisterClient(client_id, group_id);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t client_id = args[0];
const int32_t group_id = args[1];
*rv = RPCContext::getInstance()->GetClient(client_id, group_id);
});
} // namespace rpc
} // namespace dgl
......
......@@ -16,6 +16,7 @@
#include <vector>
#include <string>
#include <mutex>
#include <unordered_map>
#include "./rpc_msg.h"
#include "./tensorpipe/tp_communicator.h"
......@@ -68,7 +69,7 @@ struct RPCContext {
/*!
* \brief Barrier count of each client group, keyed by group ID
*/
int32_t barrier_count = 0;
std::unordered_map<int32_t, int32_t> barrier_count;
/*!
* \brief Total number of server per machine.
......@@ -101,6 +102,13 @@ struct RPCContext {
*/
std::shared_ptr<ServerState> server_state;
/*!
* \brief Current group ID
*/
int32_t group_id = -1;
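/*! \brief Largest global client ID assigned so far (bumped by RegisterClient) */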
int32_t curr_client_id = -1;
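/*! \brief Per-group mapping from local client ID to global client ID */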
std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;
/*! \brief Get the RPC context singleton */
static RPCContext* getInstance() {
static RPCContext ctx;
......@@ -116,12 +124,35 @@ struct RPCContext {
t->msg_seq = 0;
t->num_servers = 0;
t->num_clients = 0;
t->barrier_count = 0;
t->barrier_count.clear();
t->num_servers_per_machine = 0;
t->sender.reset();
t->receiver.reset();
t->ctx.reset();
t->server_state.reset();
t->group_id = -1;
t->curr_client_id = -1;
t->clients_.clear();
}
int32_t RegisterClient(int32_t client_id, int32_t group_id) {
auto &&m = clients_[group_id];
if (m.find(client_id) != m.end()) {
return -1;
}
m[client_id] = ++curr_client_id;
return curr_client_id;
}
int32_t GetClient(int32_t client_id, int32_t group_id) const {
if (clients_.find(group_id) == clients_.end()) {
return -1;
}
const auto &m = clients_.at(group_id);
if (m.find(client_id) == m.end()) {
return -1;
}
return m.at(client_id);
}
};
......
......@@ -38,6 +38,9 @@ struct RPCMessage : public runtime::Object {
/*! \brief Extra payloads in the form of tensors.*/
std::vector<runtime::NDArray> tensors;
/*! \brief Group ID. */
int32_t group_id{0};
bool Load(dmlc::Stream* stream) {
stream->Read(&service_id);
stream->Read(&msg_seq);
......@@ -45,6 +48,7 @@ struct RPCMessage : public runtime::Object {
stream->Read(&server_id);
stream->Read(&data);
stream->Read(&tensors);
stream->Read(&group_id);
return true;
}
......@@ -55,6 +59,7 @@ struct RPCMessage : public runtime::Object {
stream->Write(server_id);
stream->Write(data);
stream->Write(tensors);
stream->Write(group_id);
}
static constexpr const char* _type_key = "rpc.RPCMessage";
......
......@@ -28,11 +28,11 @@ def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
return dgl.from_scipy(arr)
def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
def run_server(graph_name, server_id, server_count, num_clients, shared_mem, keep_alive=False):
g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients,
'/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem,
graph_format=['csc', 'coo'])
graph_format=['csc', 'coo'], keep_alive=keep_alive)
print('start server', server_id)
g.start()
......@@ -114,16 +114,18 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
print('clients have terminated')
def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges, group_id):
os.environ['DGL_NUM_SERVER'] = str(server_count)
os.environ['DGL_GROUP_ID'] = str(group_id)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph(graph_name, gpb=gpb)
check_dist_graph(g, num_clients, num_nodes, num_edges)
def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges, group_id):
os.environ['DGL_NUM_SERVER'] = str(server_count)
os.environ['DGL_GROUP_ID'] = str(group_id)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
......@@ -278,13 +280,13 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
print('end')
def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist()
g = create_random_graph(10000)
# Partition the graph
num_parts = 1
graph_name = 'dist_graph_test_2'
graph_name = f'check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
......@@ -293,18 +295,21 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context('spawn')
keep_alive = num_groups > 1
for serv_id in range(num_servers):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
num_clients, shared_mem))
num_clients, shared_mem, keep_alive))
serv_ps.append(p)
p.start()
cli_ps = []
for cli_id in range(num_clients):
print('start client', cli_id)
for group_id in range(num_groups):
print('start client[{}] for group[{}]'.format(cli_id, group_id))
p = ctx.Process(target=run_emb_client, args=(graph_name, 0, num_servers, num_clients,
g.number_of_nodes(),
g.number_of_edges()))
g.number_of_edges(),
group_id))
p.start()
cli_ps.append(p)
......@@ -312,18 +317,23 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
p.join()
assert p.exitcode == 0
if keep_alive:
for p in serv_ps:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
for p in serv_ps:
p.join()
print('clients have terminated')
def check_server_client(shared_mem, num_servers, num_clients):
def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist()
g = create_random_graph(10000)
# Partition the graph
num_parts = 1
graph_name = 'dist_graph_test_2'
graph_name = f'check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
......@@ -332,23 +342,30 @@ def check_server_client(shared_mem, num_servers, num_clients):
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context('spawn')
keep_alive = num_groups > 1
for serv_id in range(num_servers):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
num_clients, shared_mem))
num_clients, shared_mem, keep_alive))
serv_ps.append(p)
p.start()
# launch different client groups simultaneously
cli_ps = []
for cli_id in range(num_clients):
print('start client', cli_id)
for group_id in range(num_groups):
print('start client[{}] for group[{}]'.format(cli_id, group_id))
p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(),
g.number_of_edges()))
g.number_of_edges(), group_id))
p.start()
cli_ps.append(p)
for p in cli_ps:
p.join()
if keep_alive:
for p in serv_ps:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
for p in serv_ps:
p.join()
......@@ -567,6 +584,9 @@ def test_server_client():
check_server_client(True, 1, 1)
check_server_client(False, 1, 1)
check_server_client(True, 2, 2)
check_server_client(True, 1, 1, 5)
check_server_client(False, 1, 1, 5)
check_server_client(True, 2, 2, 5)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
......@@ -577,6 +597,9 @@ def test_dist_emb_server_client():
check_dist_emb_server_client(True, 1, 1)
check_dist_emb_server_client(False, 1, 1)
check_dist_emb_server_client(True, 2, 2)
check_dist_emb_server_client(True, 1, 1, 5)
check_dist_emb_server_client(False, 1, 1, 5)
check_dist_emb_server_client(True, 2, 2, 5)
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
......
......@@ -18,10 +18,11 @@ import random
from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format=['csc', 'coo']):
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format=['csc', 'coo'],
keep_alive=False):
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
graph_format=graph_format)
graph_format=graph_format, keep_alive=keep_alive)
g.start()
......@@ -39,6 +40,32 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
dgl.distributed.exit_client()
return sampled_graph
def start_sample_client_shuffle(rank, tmpdir, disable_shared_mem, g, num_servers, group_id=0):
os.environ['DGL_GROUP_ID'] = str(group_id)
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
for i in range(num_servers):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
src, dst = sampled_graph.edges()
src = orig_nid[src]
dst = orig_nid[dst]
assert sampled_graph.number_of_nodes() == g.number_of_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
eids = g.edge_ids(src, dst)
eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
gpb = None
if disable_shared_mem:
......@@ -247,7 +274,7 @@ def test_rpc_sampling():
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling(Path(tmpdirname), 2)
def check_rpc_sampling_shuffle(tmpdir, num_server):
def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = CitationGraphDataset("cora")[0]
......@@ -260,33 +287,31 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
pserver_list = []
ctx = mp.get_context('spawn')
keep_alive = num_groups > 1
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
p = ctx.Process(target=start_server, args=(
i, tmpdir, num_server > 1, 'test_sampling', ['csc', 'coo'], keep_alive))
p.start()
time.sleep(1)
pserver_list.append(p)
sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
print("Done sampling")
pclient_list = []
num_clients = 1
for client_id in range(num_clients):
for group_id in range(num_groups):
p = ctx.Process(target=start_sample_client_shuffle, args=(client_id, tmpdir, num_server > 1, g, num_server, group_id))
p.start()
pclient_list.append(p)
for p in pclient_list:
p.join()
if keep_alive:
for p in pserver_list:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)
for p in pserver_list:
p.join()
orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
src, dst = sampled_graph.edges()
src = orig_nid[src]
dst = orig_nid[dst]
assert sampled_graph.number_of_nodes() == g.number_of_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
eids = g.edge_ids(src, dst)
eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
gpb = None
if disable_shared_mem:
......@@ -538,6 +563,7 @@ def test_rpc_sampling_shuffle(num_server):
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=5)
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
......
......@@ -38,18 +38,19 @@ class NeighborSampler(object):
return blocks
def start_server(rank, tmpdir, disable_shared_mem, num_clients):
def start_server(rank, tmpdir, disable_shared_mem, num_clients, keep_alive=False):
import dgl
print('server: #clients=' + str(num_clients))
g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem,
graph_format=['csc', 'coo'])
graph_format=['csc', 'coo'], keep_alive=keep_alive)
g.start()
def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_eid):
def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id=0):
import dgl
import torch as th
os.environ['DGL_GROUP_ID'] = str(group_id)
dgl.distributed.initialize("mp_ip_config.txt")
gpb = None
disable_shared_mem = num_server > 0
......@@ -120,7 +121,6 @@ def test_standalone(tmpdir):
start_dist_dataloader(0, tmpdir, 1, True, orig_nid, orig_eid)
except Exception as e:
print(e)
dgl.distributed.exit_client() # this is needed since there's two test here in one process
def start_dist_neg_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, groundtruth_g):
import dgl
......@@ -213,7 +213,8 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("reshuffle", [True, False])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
@pytest.mark.parametrize("num_groups", [1, 5])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, num_groups):
reset_envs()
generate_ip_config("mp_ip_config.txt", num_server, num_server)
......@@ -228,22 +229,34 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
pserver_list = []
ctx = mp.get_context('spawn')
keep_alive = num_groups > 1
for i in range(num_server):
p = ctx.Process(target=start_server, args=(
i, tmpdir, num_server > 1, num_workers+1))
i, tmpdir, num_server > 1, num_workers+1, keep_alive))
p.start()
time.sleep(1)
pserver_list.append(p)
os.environ['DGL_DIST_MODE'] = 'distributed'
os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
ptrainer = ctx.Process(target=start_dist_dataloader, args=(
0, tmpdir, num_server, drop_last, orig_nid, orig_eid))
ptrainer.start()
ptrainer_list = []
num_trainers = 1
for trainer_id in range(num_trainers):
for group_id in range(num_groups):
p = ctx.Process(target=start_dist_dataloader, args=(
trainer_id, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id))
p.start()
ptrainer_list.append(p)
for p in ptrainer_list:
p.join()
if keep_alive:
for p in pserver_list:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("mp_ip_config.txt", 1)
for p in pserver_list:
p.join()
ptrainer.join()
def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid, groundtruth_g):
import dgl
......@@ -438,7 +451,8 @@ if __name__ == "__main__":
test_dataloader(Path(tmpdirname), 3, 4, 'node')
test_dataloader(Path(tmpdirname), 3, 4, 'edge')
test_neg_dataloader(Path(tmpdirname), 3, 4)
test_dist_dataloader(Path(tmpdirname), 3, 0, True, True)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, True)
test_dist_dataloader(Path(tmpdirname), 3, 0, True, False)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, False)
for num_groups in [1, 5]:
test_dist_dataloader(Path(tmpdirname), 3, 0, True, True, num_groups)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, True, num_groups)
test_dist_dataloader(Path(tmpdirname), 3, 0, True, False, num_groups)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, False, num_groups)
......@@ -83,21 +83,24 @@ class HelloRequest(dgl.distributed.Request):
res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res
def start_server(num_clients, ip_config, server_id=0):
print("Sleep 2 seconds to test client re-connect.")
time.sleep(2)
server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1):
print("Sleep 1 seconds to test client re-connect.")
time.sleep(1)
server_state = dgl.distributed.ServerState(
None, local_g=None, partition_book=None, keep_alive=keep_alive)
dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse)
print("Start server {}".format(server_id))
dgl.distributed.start_server(server_id=server_id,
ip_config=ip_config,
num_servers=1,
num_servers=num_servers,
num_clients=num_clients,
server_state=server_state)
def start_client(ip_config):
def start_client(ip_config, group_id=0, num_servers=1):
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
dgl.distributed.connect_to_server(
ip_config=ip_config, num_servers=num_servers, group_id=group_id)
req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv
dgl.distributed.send_request(0, req)
......@@ -238,6 +241,40 @@ def test_multi_thread_rpc():
pserver.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client_groups():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = "rpc_ip_config_mul_client_groups.txt"
num_machines = 5
# should test with a larger number, but ports might already be in use.
num_servers = 1
generate_ip_config(ip_config, num_machines, num_servers)
# pressure test
num_clients = 15
num_groups = 15
ctx = mp.get_context('spawn')
pserver_list = []
for i in range(num_servers*num_machines):
pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, i, True, num_servers))
pserver.start()
pserver_list.append(pserver)
pclient_list = []
for i in range(num_clients):
for group_id in range(num_groups):
pclient = ctx.Process(target=start_client, args=(ip_config, group_id, num_servers))
pclient.start()
pclient_list.append(pclient)
for p in pclient_list:
p.join()
for p in pserver_list:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers(ip_config, num_servers)
for p in pserver_list:
p.join()
if __name__ == '__main__':
test_serialize()
test_rpc_msg()
......