Unverified Commit 5454471f authored by Da Zheng, committed by GitHub

[Distributed] Add roles (#1971)

* distinguish roles.

* add comments.

* fix lint.

* move roles to server_state.

* fix text.

* fix tests.

* fix tests.

* Revert "fix tests."

This reverts commit 5baa136b872a4550d4e612bfb1dfe363d7814adf.
parent 2823c61f
......@@ -8,6 +8,7 @@ from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore
from .rpc_client import connect_to_server, shutdown_servers
from .role import init_role
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
......@@ -29,6 +30,7 @@ def _init_rpc(ip_config, max_queue_size, net_type, role):
'''
try:
connect_to_server(ip_config, max_queue_size, net_type)
init_role(role)
init_kvstore(ip_config, role)
except Exception as e:
print(e, flush=True)
......@@ -59,7 +61,8 @@ def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type
net_type, 'sampler'))
NUM_SAMPLER_WORKERS = num_workers
connect_to_server(ip_config, max_queue_size, net_type)
init_kvstore(ip_config)
init_role('default')
init_kvstore(ip_config, 'default')
def finalize_client():
......
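With the change above, initialize() both registers the 'default' (trainer) role and creates the KVStore client. A minimal sketch of the trainer-side entry point under that assumption; the ip_config file name and worker count are placeholder values, not part of this commit:

# Sketch only: a trainer process in distributed mode.
# 'ip_config.txt' and num_workers=4 are assumed values for illustration.
import dgl
dgl.distributed.initialize(ip_config='ip_config.txt', num_workers=4)
# ... build a DistGraph and run training ...
dgl.distributed.exit_client()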
......@@ -8,7 +8,7 @@ from ..heterograph import DGLHeteroGraph
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID
from .kvstore import KVServer, init_kvstore, get_kvstore
from .kvstore import KVServer, get_kvstore
from .standalone_kvstore import KVClient as SA_KVClient
from .._ffi.ndarray import empty_shared_mem
from ..frame import infer_scheme
......@@ -17,6 +17,7 @@ from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from . import role
from .server_state import ServerState
from .rpc_server import start_server
from .graph_services import find_edges as dist_find_edges
......@@ -360,10 +361,7 @@ class DistGraph:
self._num_edges += int(part_md['num_edges'])
def _init(self):
# Init KVStore client if it's not initialized yet.
init_kvstore(self.ip_config)
self._client = get_kvstore()
self._g = _get_graph_from_shared_mem(self.graph_name)
self._gpb = get_shared_mem_partition_book(self.graph_name, self._g)
if self._gpb is None:
......@@ -488,20 +486,7 @@ class DistGraph:
int
The rank of the current graph store.
'''
# If DistGraph doesn't have a local partition, it doesn't matter what rank
# it returns. There is no data locality any way, as long as the returned rank
# is unique in the system.
if self._g is None:
return rpc.get_rank()
else:
# If DistGraph has a local partition, we should be careful about the rank
# we return. We need to return a rank that node_split or edge_split can split
# the workload with respect to data locality.
num_client = rpc.get_num_client()
num_client_per_part = num_client // self._gpb.num_partitions()
# all ranks of the clients in the same machine are in a contiguous range.
client_id_in_part = rpc.get_rank() % num_client_per_part
return int(self._gpb.partid * num_client_per_part + client_id_in_part)
return role.get_global_rank()
def find_edges(self, edges):
""" Given an edge ID array, return the source
......@@ -590,10 +575,12 @@ def _get_overlap(mask_arr, ids):
def _split_local(partition_book, rank, elements, local_eles):
''' Split the input element list with respect to data locality.
'''
num_clients = rpc.get_num_client()
num_clients = role.get_num_trainers()
num_client_per_part = num_clients // partition_book.num_partitions()
if rank is None:
rank = rpc.get_rank()
rank = role.get_trainer_rank()
assert rank < num_clients, \
'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
# all ranks of the clients in the same machine are in a contiguous range.
client_id_in_part = rank % num_client_per_part
local_eles = _get_overlap(elements, local_eles)
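To make the locality-aware split concrete, here is the rank arithmetic from the hunk above worked through with assumed numbers (8 trainers spread over 2 partitions):

# Assumed numbers for illustration only.
num_clients = 8                                  # role.get_num_trainers()
num_parts = 2                                    # partition_book.num_partitions()
num_client_per_part = num_clients // num_parts   # 4 trainers per partition
rank = 6                                         # role.get_trainer_rank()
# trainer ranks within one machine are contiguous, so the in-partition id is:
client_id_in_part = rank % num_client_per_part   # 6 % 4 = 2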
......@@ -609,11 +596,14 @@ def _split_local(partition_book, rank, elements, local_eles):
def _split_even(partition_book, rank, elements):
''' Split the input element list evenly.
'''
num_clients = rpc.get_num_client()
num_clients = role.get_num_trainers()
num_client_per_part = num_clients // partition_book.num_partitions()
if rank is None:
rank = rpc.get_rank()
# all ranks of the clients in the same machine are in a contiguous range.
if rank is None:
rank = role.get_trainer_rank()
assert rank < num_clients, \
'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
# This conversion of rank is to make the new rank aligned with partitioning.
client_id_in_part = rank % num_client_per_part
rank = client_id_in_part + num_client_per_part * partition_book.partid
......
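_split_even applies the same per-partition arithmetic and then realigns the rank with the partition served by the current machine. A small worked example with assumed numbers:

# Assumed numbers for illustration only.
num_client_per_part = 4
rank = 3                                         # caller-supplied rank
partid = 1                                       # partition_book.partid
client_id_in_part = rank % num_client_per_part   # 3 % 4 = 3
rank = client_id_in_part + num_client_per_part * partid   # 3 + 4 * 1 = 7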
......@@ -214,13 +214,14 @@ class BarrierRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
role = server_state.roles
count = kv_store.barrier_count[self.role]
kv_store.barrier_count[self.role] = count + 1
if kv_store.barrier_count[self.role] == len(kv_store.role[self.role]):
if kv_store.barrier_count[self.role] == len(role[self.role]):
kv_store.barrier_count[self.role] = 0
res_list = []
for target_id in kv_store.role[self.role]:
res_list.append((target_id, BarrierResponse(BARRIER_MSG)))
for client_id, _ in role[self.role]:
res_list.append((client_id, BarrierResponse(BARRIER_MSG)))
return res_list
return None
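The barrier now counts against the role table kept in server_state instead of a KVStore-private copy. A rough illustration of the bookkeeping, with assumed contents for two trainer clients:

# Assumed contents for illustration only.
roles = {'default': {(0, 0), (1, 0)}}    # server_state.roles: (client_id, machine_id)
barrier_count = {'default': 0}           # kv_store.barrier_count
# Each BarrierRequest for 'default' increments the counter. When it reaches
# len(roles['default']) == 2, a BarrierResponse is sent to client ids 0 and 1
# and the counter resets to 0.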
......@@ -529,59 +530,6 @@ class DeleteDataRequest(rpc.Request):
res = DeleteDataResponse(DELETE_MSG)
return res
REGISTER_ROLE = 901241
ROLE_MSG = "Register_Role"
class RegisterRoleResponse(rpc.Response):
"""Send a confirmation signal (just a short string message)
of RegisterRoleRequest to client.
"""
def __init__(self, msg):
self.msg = msg
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
class RegisterRoleRequest(rpc.Request):
"""Send client id and role to server
Parameters
----------
client_id : int
ID of client
role : str
role of client
"""
def __init__(self, client_id, role):
self.client_id = client_id
self.role = role
def __getstate__(self):
return self.client_id, self.role
def __setstate__(self, state):
self.client_id, self.role = state
def process_request(self, server_state):
kv_store = server_state.kv_store
role = kv_store.role
if self.role not in role:
role[self.role] = set()
kv_store.barrier_count[self.role] = 0
role[self.role].add(self.client_id)
total_count = 0
for key in role:
total_count += len(role[key])
if total_count == kv_store.num_clients:
res_list = []
for target_id in range(kv_store.num_clients):
res_list.append((target_id, RegisterRoleResponse(ROLE_MSG)))
return res_list
return None
############################ KVServer ###############################
def default_push_handler(target, name, id_tensor, data_tensor):
......@@ -680,9 +628,6 @@ class KVServer(object):
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store the partition information with specified data name
......@@ -703,8 +648,6 @@ class KVServer(object):
# push and pull handler
self._push_handlers = {}
self._pull_handlers = {}
# store client role
self._role = {}
@property
def server_id(self):
......@@ -746,11 +689,6 @@ class KVServer(object):
"""Get push handler"""
return self._push_handlers
@property
def role(self):
"""Get client role"""
return self._role
@property
def pull_handlers(self):
"""Get pull handler"""
......@@ -872,9 +810,6 @@ class KVClient(object):
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store the partition information with specified data name
......@@ -897,10 +832,6 @@ class KVClient(object):
self._push_handlers = {}
# register role on server-0
self._role = role
request = RegisterRoleRequest(self._client_id, self._role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == ROLE_MSG
@property
def client_id(self):
......@@ -1298,7 +1229,7 @@ class KVClient(object):
KVCLIENT = None
def init_kvstore(ip_config, role='default'):
def init_kvstore(ip_config, role):
"""initialize KVStore"""
global KVCLIENT
if KVCLIENT is None:
......
"""Manage the roles in different clients.
Right now, the clients have different roles. Some clients work as samplers and
some work as trainers.
"""
import os
import numpy as np
from . import rpc
REGISTER_ROLE = 700001
REG_ROLE_MSG = "Register_Role"
class RegisterRoleResponse(rpc.Response):
"""Send a confirmation signal (just a short string message)
of RegisterRoleRequest to the client.
"""
def __init__(self, msg):
self.msg = msg
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
class RegisterRoleRequest(rpc.Request):
"""Send client id and role to server
Parameters
----------
client_id : int
ID of client
role : str
role of client
"""
def __init__(self, client_id, machine_id, role):
self.client_id = client_id
self.machine_id = machine_id
self.role = role
def __getstate__(self):
return self.client_id, self.machine_id, self.role
def __setstate__(self, state):
self.client_id, self.machine_id, self.role = state
def process_request(self, server_state):
kv_store = server_state.kv_store
role = server_state.roles
if self.role not in role:
role[self.role] = set()
if kv_store is not None:
kv_store.barrier_count[self.role] = 0
role[self.role].add((self.client_id, self.machine_id))
total_count = 0
for key in role:
total_count += len(role[key])
# Clients are blocked until all clients register their roles.
if total_count == rpc.get_num_client():
res_list = []
for target_id in range(rpc.get_num_client()):
res_list.append((target_id, RegisterRoleResponse(REG_ROLE_MSG)))
return res_list
return None
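For reference, an assumed shape of server_state.roles after every client has registered; two trainers and two samplers spread over two machines:

# Assumed contents for illustration only: role name -> {(client_id, machine_id)}.
roles = {
    'default': {(0, 0), (2, 1)},   # trainer processes
    'sampler': {(1, 0), (3, 1)},   # sampler processes
}
# total_count == 4 == rpc.get_num_client(), so every client receives a
# RegisterRoleResponse and unblocks.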
GET_ROLE = 700002
GET_ROLE_MSG = "Get_Role"
class GetRoleResponse(rpc.Response):
"""Send the roles of all client processes"""
def __init__(self, role):
self.role = role
self.msg = GET_ROLE_MSG
def __getstate__(self):
return self.role, self.msg
def __setstate__(self, state):
self.role, self.msg = state
class GetRoleRequest(rpc.Request):
"""Send a request to get the roles of all client processes."""
def __init__(self):
self.msg = GET_ROLE_MSG
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
def process_request(self, server_state):
return GetRoleResponse(server_state.roles)
# The key is the role, the value is a dict mapping the RPC rank to a rank within the role.
PER_ROLE_RANK = {}
# A dict mapping the RPC rank of a client process to its global rank. The client processes
# of the same role have global ranks that fall in a contiguous range.
GLOBAL_RANK = {}
# The role of the current process
CUR_ROLE = None
def init_role(role):
"""Initialize the role of the current process.
Each process is associated with a role so that we can determine which
functions can be invoked in a process. For example, we do not allow some
functions in sampler processes.
The initialization registers the role of the current process and
gets the roles of all client processes. It also computes the rank of every
client process in a deterministic way so that all clients compute the same
rank for the same client process.
"""
global CUR_ROLE
CUR_ROLE = role
global PER_ROLE_RANK
global GLOBAL_RANK
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
assert role == 'default'
GLOBAL_RANK[0] = 0
PER_ROLE_RANK['default'] = {0:0}
# Register the current role. This blocks until all clients register themselves.
client_id = rpc.get_rank()
machine_id = rpc.get_machine_id()
request = RegisterRoleRequest(client_id, machine_id, role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == REG_ROLE_MSG
# Get all clients on all machines.
request = GetRoleRequest()
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == GET_ROLE_MSG
# Here we want to compute a new rank for each client.
# We compute the per-role rank as well as global rank.
# For the per-role rank, we ensure that all ranks within a machine are contiguous.
# For global rank, we also ensure that all ranks within a machine are contiguous,
# and all ranks within a role are contiguous.
global_rank = 0
# We want to ensure that the global ranks of the trainer processes start from 0.
role_names = ['default']
for role_name in response.role:
if role_name not in role_names:
role_names.append(role_name)
for role_name in role_names:
# Let's collect the ranks of this role in all machines.
machines = {}
for client_id, machine_id in response.role[role_name]:
if machine_id not in machines:
machines[machine_id] = []
machines[machine_id].append(client_id)
num_machines = len(machines)
PER_ROLE_RANK[role_name] = {}
per_role_rank = 0
for i in range(num_machines):
clients = machines[i]
clients = np.sort(clients)
for client_id in clients:
GLOBAL_RANK[client_id] = global_rank
global_rank += 1
PER_ROLE_RANK[role_name][client_id] = per_role_rank
per_role_rank += 1
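A worked example of the rank computation, continuing the assumed four clients above. Trainers ('default') are processed first, so their global ranks start at 0; within each role, clients are grouped by machine and sorted by RPC rank, which keeps the ranks of one machine contiguous:

# Assumed input: roles = {'default': {(0, 0), (2, 1)}, 'sampler': {(1, 0), (3, 1)}}
# Resulting maps (RPC rank -> new rank):
#   GLOBAL_RANK              = {0: 0, 2: 1, 1: 2, 3: 3}
#   PER_ROLE_RANK['default'] = {0: 0, 2: 1}
#   PER_ROLE_RANK['sampler'] = {1: 0, 3: 1}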
def get_global_rank():
"""Get the global rank
The rank can globally identify the client process. For the client processes
of the same role, their ranks are in a contiguous range.
"""
return GLOBAL_RANK[rpc.get_rank()]
def get_rank(role):
"""Get the role-specific rank"""
return PER_ROLE_RANK[role][rpc.get_rank()]
def get_trainer_rank():
"""Get the rank of the current trainer process.
This function can only be called in a trainer process. It results in
an error if it is called in a process with another role.
"""
assert CUR_ROLE == 'default'
return PER_ROLE_RANK['default'][rpc.get_rank()]
def get_role():
"""Get the role of the current process"""
return CUR_ROLE
def get_num_trainers():
"""Get the number of trainer processes"""
return len(PER_ROLE_RANK['default'])
rpc.register_service(REGISTER_ROLE, RegisterRoleRequest, RegisterRoleResponse)
rpc.register_service(GET_ROLE, GetRoleRequest, GetRoleResponse)
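A minimal sketch of how these helpers are meant to be used inside a trainer process (it mirrors the assertions in the updated test below); the concrete values depend on the launch configuration:

# Sketch only: valid in a client initialized with the 'default' (trainer) role.
from dgl.distributed import role
num_trainers = role.get_num_trainers()
print('trainer rank:', role.get_trainer_rank())   # 0 <= rank < num_trainers
print('global rank:', role.get_global_rank())     # unique across all clients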
......@@ -44,6 +44,12 @@ class ServerState:
self._kv_store = kv_store
self._graph = local_g
self.partition_book = partition_book
self._roles = {}
@property
def roles(self):
"""Roles of the client processes"""
return self._roles
@property
def kv_store(self):
......
......@@ -239,6 +239,12 @@ def test_server_client():
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_standalone():
os.environ['DGL_DIST_MODE'] = 'standalone'
# TODO(zhengda) this is a temporary fix. We need to make initialize work
# for standalone mode as well.
dgl.distributed.role.CUR_ROLE = 'default'
dgl.distributed.role.GLOBAL_RANK = {-1:0}
dgl.distributed.role.PER_ROLE_RANK['default'] = {-1:0}
g = create_random_graph(10000)
# Partition the graph
num_parts = 1
......@@ -261,8 +267,17 @@ def test_split():
edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0]
# The code now collects the roles of all client processes and uses the information
# to determine how to split the workloads. Here we simulate the multi-client
# use case.
def set_roles(num_clients):
dgl.distributed.role.CUR_ROLE = 'default'
dgl.distributed.role.GLOBAL_RANK = {i:i for i in range(num_clients)}
dgl.distributed.role.PER_ROLE_RANK['default'] = {i:i for i in range(num_clients)}
for i in range(num_parts):
dgl.distributed.set_num_client(num_parts)
set_roles(num_parts)
part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
......@@ -273,13 +288,13 @@ def test_split():
for n in nodes1:
assert n in local_nids
dgl.distributed.set_num_client(num_parts * 2)
set_roles(num_parts * 2)
nodes3 = node_split(node_mask, gpb, i * 2, force_even=False)
nodes4 = node_split(node_mask, gpb, i * 2 + 1, force_even=False)
nodes5 = F.cat([nodes3, nodes4], 0)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))
dgl.distributed.set_num_client(num_parts)
set_roles(num_parts)
local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
edges1 = np.intersect1d(selected_edges, F.asnumpy(local_eids))
......@@ -289,7 +304,7 @@ def test_split():
for e in edges1:
assert e in local_eids
dgl.distributed.set_num_client(num_parts * 2)
set_roles(num_parts * 2)
edges3 = edge_split(edge_mask, gpb, i * 2, force_even=False)
edges4 = edge_split(edge_mask, gpb, i * 2 + 1, force_even=False)
edges5 = F.cat([edges3, edges4], 0)
......@@ -310,8 +325,17 @@ def test_split_even():
all_nodes2 = []
all_edges1 = []
all_edges2 = []
# The code now collects the roles of all client processes and uses the information
# to determine how to split the workloads. Here we simulate the multi-client
# use case.
def set_roles(num_clients):
dgl.distributed.role.CUR_ROLE = 'default'
dgl.distributed.role.GLOBAL_RANK = {i:i for i in range(num_clients)}
dgl.distributed.role.PER_ROLE_RANK['default'] = {i:i for i in range(num_clients)}
for i in range(num_parts):
dgl.distributed.set_num_client(num_parts)
set_roles(num_parts)
part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
......@@ -320,7 +344,7 @@ def test_split_even():
subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
print('part {} get {} nodes and {} are in the partition'.format(i, len(nodes), len(subset)))
dgl.distributed.set_num_client(num_parts * 2)
set_roles(num_parts * 2)
nodes1 = node_split(node_mask, gpb, i * 2, force_even=True)
nodes2 = node_split(node_mask, gpb, i * 2 + 1, force_even=True)
nodes3 = F.cat([nodes1, nodes2], 0)
......@@ -328,7 +352,7 @@ def test_split_even():
subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
print('intersection has', len(subset))
dgl.distributed.set_num_client(num_parts)
set_roles(num_parts)
local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
edges = edge_split(edge_mask, gpb, i, force_even=True)
......@@ -336,7 +360,7 @@ def test_split_even():
subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
print('part {} get {} edges and {} are in the partition'.format(i, len(edges), len(subset)))
dgl.distributed.set_num_client(num_parts * 2)
set_roles(num_parts * 2)
edges1 = edge_split(edge_mask, gpb, i * 2, force_even=True)
edges2 = edge_split(edge_mask, gpb, i * 2 + 1, force_even=True)
edges3 = F.cat([edges1, edges2], 0)
......
......@@ -5,7 +5,7 @@ import socket
from scipy import sparse as spsp
import dgl
import backend as F
import unittest, pytest
import unittest
from dgl.graph_index import create_graph_index
import multiprocessing as mp
from numpy.testing import assert_array_equal
......@@ -147,8 +147,9 @@ def start_server_mul_role(server_id, num_clients):
server_state=server_state)
def start_client(num_clients):
os.environ['DGL_DIST_MODE'] = 'distributed'
# Note: connect to server first !
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
dgl.distributed.initialize(ip_config='kv_ip_config.txt')
# Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
kvclient.map_shared_data(partition_book=gpb)
......@@ -275,19 +276,22 @@ def start_client(num_clients):
data_tensor = data_tensor * num_clients
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
def start_client_mul_role(i, num_clients):
# Note: connect to server first !
dgl.distributed.connect_to_server(ip_config='kv_ip_mul_config.txt')
# Init kvclient
if i % 2 == 0:
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='trainer')
else:
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='sampler')
if i == 2: # block one trainer
def start_client_mul_role(i, num_workers):
os.environ['DGL_DIST_MODE'] = 'distributed'
# initialize() creates the kvstore.
dgl.distributed.initialize(ip_config='kv_ip_mul_config.txt', num_workers=num_workers)
if i == 0: # block one trainer
time.sleep(5)
kvclient = dgl.distributed.kvstore.get_kvstore()
kvclient.barrier()
print("i: %d role: %s" % (i, kvclient.role))
assert dgl.distributed.role.get_num_trainers() == 2
assert dgl.distributed.role.get_trainer_rank() < 2
print('trainer rank: %d, global rank: %d' % (dgl.distributed.role.get_trainer_rank(),
dgl.distributed.role.get_global_rank()))
dgl.distributed.exit_client()
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store():
ip_config = open("kv_ip_config.txt", "w")
......@@ -316,7 +320,10 @@ def test_kv_store():
def test_kv_multi_role():
ip_config = open("kv_ip_mul_config.txt", "w")
num_servers = 2
num_clients = 10
num_trainers = 2
num_samplers = 2
# There are two trainer processes and each trainer process has two sampler processes.
num_clients = num_trainers * (1 + num_samplers)
ip_addr = get_local_usable_addr()
ip_config.write('{} {}\n'.format(ip_addr, num_servers))
ip_config.close()
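For the values above this works out as follows (worked arithmetic; each trainer later calls initialize(num_workers=2), which spawns its sampler workers):

# num_clients = num_trainers * (1 + num_samplers) = 2 * (1 + 2) = 6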
......@@ -327,11 +334,11 @@ def test_kv_multi_role():
pserver = ctx.Process(target=start_server_mul_role, args=(i, num_clients))
pserver.start()
pserver_list.append(pserver)
for i in range(num_clients):
pclient = ctx.Process(target=start_client_mul_role, args=(i, num_clients))
for i in range(num_trainers):
pclient = ctx.Process(target=start_client_mul_role, args=(i, num_samplers))
pclient.start()
pclient_list.append(pclient)
for i in range(num_clients):
for i in range(num_trainers):
pclient_list[i].join()
for i in range(num_servers):
pserver_list[i].join()
......