Unverified Commit a208e886 authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
"""DGL distributed module"""
from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split
from . import optim
from .dist_context import exit_client, initialize
from .dist_dataloader import DistDataLoader
from .dist_graph import DistGraph, DistGraphServer, edge_split, node_split
from .dist_tensor import DistTensor
from .partition import partition_graph, load_partition, load_partition_feats, load_partition_book
from .graph_partition_book import GraphPartitionBook, PartitionPolicy
from .graph_services import *
from .kvstore import KVClient, KVServer
from .nn import *
from . import optim
from .partition import (
load_partition,
load_partition_book,
load_partition_feats,
partition_graph,
)
from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, shutdown_servers
from .dist_context import initialize, exit_client
from .kvstore import KVServer, KVClient
from .rpc_server import start_server
from .server_state import ServerState
from .dist_dataloader import DistDataLoader
from .graph_services import *
"""Define all the constants used by DGL rpc"""
# Maximum size of message queue in bytes
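# (20 * 1024**3 bytes = 20 GiB)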
MAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024
SERVER_EXIT = "server_exit"
SERVER_KEEP_ALIVE = "server_keep_alive"
DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
"""Initialize the distributed services"""
# pylint: disable=line-too-long
import atexit
import gc
import multiprocessing as mp
import os
import queue
import sys
import time
import traceback
from enum import Enum

from .. import utils
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import close_kvstore, init_kvstore
from .role import init_role
from .rpc_client import connect_to_server
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
@@ -34,13 +34,22 @@ def get_sampler_pool():
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(
    ip_config,
    num_servers,
    max_queue_size,
    net_type,
    role,
    num_threads,
    group_id,
):
    """This init function is called in the worker processes."""
try:
utils.set_num_threads(num_threads)
        if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
            connect_to_server(
                ip_config, num_servers, max_queue_size, net_type, group_id
            )
init_role(role)
init_kvstore(ip_config, num_servers, role)
except Exception as e:
@@ -51,6 +60,7 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread
class MpCommand(Enum):
"""Enum class for multiprocessing command"""
INIT_RPC = 0 # Not used in the task queue
SET_COLLATE_FN = 1
CALL_BARRIER = 2
@@ -80,12 +90,16 @@ def init_process(rpc_config, mp_contexts):
elif command == MpCommand.CALL_BARRIER:
barrier.wait()
elif command == MpCommand.DELETE_COLLATE_FN:
            (dataloader_name,) = args
del collate_fn_dict[dataloader_name]
elif command == MpCommand.CALL_COLLATE_FN:
dataloader_name, collate_args = args
data_queue.put(
                (
                    dataloader_name,
                    collate_fn_dict[dataloader_name](collate_args),
                )
            )
elif command == MpCommand.CALL_FN_ALL_WORKERS:
func, func_args = args
func(func_args)
@@ -101,6 +115,7 @@ def init_process(rpc_config, mp_contexts):
class CustomPool:
"""Customized worker pool"""
def __init__(self, num_workers, rpc_config):
"""
Customized worker pool init function
@@ -117,8 +132,13 @@ class CustomPool:
for _ in range(num_workers):
task_queue = ctx.Queue(self.queue_size)
self.task_queues.append(task_queue)
            proc = ctx.Process(
                target=init_process,
                args=(
                    rpc_config,
                    (self.result_queue, task_queue, self.barrier),
                ),
            )
proc.daemon = True
proc.start()
self.process_list.append(proc)
@@ -127,20 +147,23 @@ class CustomPool:
"""Set collate function in subprocess"""
for i in range(self.num_workers):
self.task_queues[i].put(
                (MpCommand.SET_COLLATE_FN, (dataloader_name, func))
            )
def submit_task(self, dataloader_name, args):
"""Submit task to workers"""
# Round robin
self.task_queues[self.current_proc_id].put(
            (MpCommand.CALL_COLLATE_FN, (dataloader_name, args))
        )
self.current_proc_id = (self.current_proc_id + 1) % self.num_workers
def submit_task_to_all_workers(self, func, args):
"""Submit task to all workers"""
for i in range(self.num_workers):
self.task_queues[i].put(
                (MpCommand.CALL_FN_ALL_WORKERS, (func, args))
            )
def get_result(self, dataloader_name, timeout=1800):
"""Get result from result queue"""
@@ -152,20 +175,21 @@ class CustomPool:
"""Delete collate function"""
for i in range(self.num_workers):
self.task_queues[i].put(
                (MpCommand.DELETE_COLLATE_FN, (dataloader_name,))
            )
def call_barrier(self):
"""Call barrier at all workers"""
for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.CALL_BARRIER, tuple()))
def close(self):
"""Close worker pool"""
for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.FINALIZE_POOL, tuple()), block=False
            )
            time.sleep(0.5)  # Fix for early python version
def join(self):
"""Join the close process of worker pool"""
@@ -173,8 +197,12 @@ class CustomPool:
self.process_list[i].join()
def initialize(
    ip_config,
    max_queue_size=MAX_QUEUE_SIZE,
    net_type="socket",
    num_worker_threads=1,
):
"""Initialize DGL's distributed module
This function initializes DGL's distributed module. It acts differently in server
@@ -204,59 +232,84 @@ def initialize(ip_config, max_queue_size=MAX_QUEUE_SIZE,
    distributed API. For example, when used with PyTorch, users have to invoke this function
    before PyTorch's `torch.distributed.init_process_group`.
"""
    if os.environ.get("DGL_ROLE", "client") == "server":
from .dist_graph import DistGraphServer
        assert (
            os.environ.get("DGL_SERVER_ID") is not None
        ), "Please define DGL_SERVER_ID to run DistGraph server"
        assert (
            os.environ.get("DGL_IP_CONFIG") is not None
        ), "Please define DGL_IP_CONFIG to run DistGraph server"
        assert (
            os.environ.get("DGL_NUM_SERVER") is not None
        ), "Please define DGL_NUM_SERVER to run DistGraph server"
        assert (
            os.environ.get("DGL_NUM_CLIENT") is not None
        ), "Please define DGL_NUM_CLIENT to run DistGraph server"
        assert (
            os.environ.get("DGL_CONF_PATH") is not None
        ), "Please define DGL_CONF_PATH to run DistGraph server"
        formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
formats = [f.strip() for f in formats]
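        # Illustrative launch environment (assumed values): a server process is
        # typically started with e.g. DGL_ROLE=server, DGL_SERVER_ID=0,
        # DGL_IP_CONFIG=ip_config.txt, DGL_NUM_SERVER=1, DGL_NUM_CLIENT=4 and
        # DGL_CONF_PATH=graph.json exported by the launch script.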
rpc.reset()
        keep_alive = bool(int(os.environ.get("DGL_KEEP_ALIVE", 0)))
        serv = DistGraphServer(
            int(os.environ.get("DGL_SERVER_ID")),
            os.environ.get("DGL_IP_CONFIG"),
            int(os.environ.get("DGL_NUM_SERVER")),
            int(os.environ.get("DGL_NUM_CLIENT")),
            os.environ.get("DGL_CONF_PATH"),
            graph_format=formats,
            keep_alive=keep_alive,
            net_type=net_type,
        )
serv.start()
sys.exit()
else:
        num_workers = int(os.environ.get("DGL_NUM_SAMPLER", 0))
        num_servers = int(os.environ.get("DGL_NUM_SERVER", 1))
        group_id = int(os.environ.get("DGL_GROUP_ID", 0))
rpc.reset()
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
        is_standalone = (
            os.environ.get("DGL_DIST_MODE", "standalone") == "standalone"
        )
if num_workers > 0 and not is_standalone:
            SAMPLER_POOL = CustomPool(
                num_workers,
                (
                    ip_config,
                    num_servers,
                    max_queue_size,
                    net_type,
                    "sampler",
                    num_worker_threads,
                    group_id,
                ),
            )
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
            assert (
                num_servers is not None and num_servers > 0
            ), "The number of servers per machine must be specified with a positive number."
            connect_to_server(
                ip_config,
                num_servers,
                max_queue_size,
                net_type,
                group_id=group_id,
            )
            init_role("default")
            init_kvstore(ip_config, num_servers, "default")
def finalize_client():
"""Release resources of this client."""
    if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
rpc.finalize_sender()
rpc.finalize_receiver()
@@ -268,7 +321,7 @@ def _exit():
def finalize_worker():
"""Finalize workers
    Python's multiprocessing pool does not call atexit functions when it closes.
    """
global SAMPLER_POOL
if SAMPLER_POOL is not None:
@@ -284,8 +337,7 @@ def join_finalize_worker():
def is_initialized():
"""Is RPC initialized?
"""
"""Is RPC initialized?"""
return INITIALIZED
@@ -297,6 +349,7 @@ def _shutdown_servers():
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)
def exit_client():
"""Trainer exits
@@ -308,12 +361,15 @@ def exit_client():
needs to call `exit_client` before calling `initialize` again.
"""
    # Only the client with rank 0 sends a shutdown request to the servers.
print("Client[{}] in group[{}] is exiting...".format(
rpc.get_rank(), rpc.get_group_id()))
print(
"Client[{}] in group[{}] is exiting...".format(
rpc.get_rank(), rpc.get_group_id()
)
)
    finalize_worker()  # Finalizing workers must happen before the barrier, and is non-blocking.
    # Garbage-collect objects such as DistTensor before exiting.
gc.collect()
    if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
rpc.client_barrier()
_shutdown_servers()
finalize_client()
...
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
from .. import backend as F
from .dist_context import get_sampler_pool
__all__ = ["DistDataLoader"]
@@ -61,8 +61,15 @@ class DistDataLoader:
and [3, 4] is not guaranteed.
"""
    def __init__(
        self,
        dataset,
        batch_size,
        shuffle=False,
        collate_fn=None,
        drop_last=False,
        queue_size=None,
    ):
self.pool, self.num_workers = get_sampler_pool()
if queue_size is None:
queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
@@ -71,7 +78,7 @@ class DistDataLoader:
self.num_pending = 0
self.collate_fn = collate_fn
self.current_pos = 0
        self.queue = []  # Only used when pool is None
self.drop_last = drop_last
self.recv_idxs = 0
self.shuffle = shuffle
@@ -153,7 +160,7 @@ class DistDataLoader:
end_pos = len(self.dataset)
else:
end_pos = self.current_pos + self.batch_size
        idx = self.data_idx[self.current_pos : end_pos].tolist()
ret = [self.dataset[i] for i in idx]
        # Sharing a large number of tensors between processes consumes too many
        # file descriptors, so convert each tensor to a scalar value beforehand.
...
@@ -3,17 +3,24 @@
import pickle
from abc import ABC
from ast import literal_eval

import numpy as np

from .. import backend as F
from .. import utils
from .._ffi.ndarray import empty_shared_mem
from ..base import EID, NID, DGLError, dgl_warning
from ..ndarray import exist_shared_mem_array
from ..partition import NDArrayPartition
from .constants import DEFAULT_ETYPE, DEFAULT_NTYPE
from .id_map import IdMap
from .shared_mem_utils import (
    DTYPE_DICT,
    _get_edata_path,
    _get_ndata_path,
    _to_shared_mem,
)
def _str_to_tuple(s):
try:
@@ -22,9 +29,18 @@ def _str_to_tuple(s):
ret = s
return ret
def _move_metadata_to_shared_mem(
    graph_name,
    num_nodes,
    num_edges,
    part_id,
    num_partitions,
    node_map,
    edge_map,
    is_range_part,
):
    """Move all metadata of the partition book to the shared memory.
    This metadata will be used to construct the graph partition book.
@@ -56,17 +72,28 @@ def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id,
The first tensor stores the serialized metadata, the second tensor stores the serialized
node map and the third tensor stores the serialized edge map. All tensors are stored in
shared memory.
    """
    meta = _to_shared_mem(
        F.tensor(
            [
                int(is_range_part),
                num_nodes,
                num_edges,
                num_partitions,
                part_id,
                len(node_map),
                len(edge_map),
            ]
        ),
        _get_ndata_path(graph_name, "meta"),
    )
    node_map = _to_shared_mem(node_map, _get_ndata_path(graph_name, "node_map"))
    edge_map = _to_shared_mem(edge_map, _get_edata_path(graph_name, "edge_map"))
return meta, node_map, edge_map
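# Illustrative layout (assumed values, not from this commit): for a range
# partition book with 2 partitions, 1000 nodes and 5000 edges, partition 0
# would serialize its meta tensor as
# [1, 1000, 5000, 2, 0, len(node_map), len(edge_map)].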
def _get_shared_mem_metadata(graph_name):
    """Get the metadata of the graph from shared memory.
    The server serializes the metadata of a graph and stores it in shared memory.
    The client needs to deserialize the data in shared memory and get the metadata
@@ -85,24 +112,38 @@ def _get_shared_mem_metadata(graph_name):
the third element is the number of partitions;
the fourth element is the tensor that stores the serialized result of node maps;
the fifth element is the tensor that stores the serialized result of edge maps.
    """
# The metadata has 7 elements: is_range_part, num_nodes, num_edges, num_partitions, part_id,
# the length of node map and the length of the edge map.
shape = (7,)
dtype = F.int64
dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(
        _get_ndata_path(graph_name, "meta"), False, shape, dtype
    )
dlpack = data.to_dlpack()
meta = F.asnumpy(F.zerocopy_from_dlpack(dlpack))
    (
        is_range_part,
        _,
        _,
        num_partitions,
        part_id,
        node_map_len,
        edge_map_len,
    ) = meta
# Load node map
    data = empty_shared_mem(
        _get_ndata_path(graph_name, "node_map"), False, (node_map_len,), dtype
    )
dlpack = data.to_dlpack()
node_map = F.zerocopy_from_dlpack(dlpack)
# Load edge_map
    data = empty_shared_mem(
        _get_edata_path(graph_name, "edge_map"), False, (edge_map_len,), dtype
    )
dlpack = data.to_dlpack()
edge_map = F.zerocopy_from_dlpack(dlpack)
@@ -110,7 +151,7 @@ def _get_shared_mem_metadata(graph_name):
def get_shared_mem_partition_book(graph_name, graph_part):
    """Get a graph partition book from shared memory.
A graph partition book of a specific graph can be serialized to shared memory.
We can reconstruct a graph partition book from shared memory.
@@ -126,11 +167,16 @@ def get_shared_mem_partition_book(graph_name, graph_part):
-------
GraphPartitionBook
A graph partition book for a particular partition.
    """
    if not exist_shared_mem_array(_get_ndata_path(graph_name, "meta")):
        return None
    (
        is_range_part,
        part_id,
        num_parts,
        node_map_data,
        edge_map_data,
    ) = _get_shared_mem_metadata(graph_name)
if is_range_part == 1:
# node ID ranges and edge ID ranges are stored in the order of node type IDs
# and edge type IDs.
@@ -150,12 +196,17 @@ def get_shared_mem_partition_book(graph_name, graph_part):
for i, (etype, eid_range) in enumerate(edge_map_data):
etypes[etype] = i
edge_map[etype] = eid_range
        return RangePartitionBook(
            part_id, num_parts, node_map, edge_map, ntypes, etypes
        )
else:
        return BasicPartitionBook(
            part_id, num_parts, node_map_data, edge_map_data, graph_part
        )
def get_node_partition_from_book(book, device):
""" Get an NDArrayPartition of the nodes from a RangePartitionBook.
"""Get an NDArrayPartition of the nodes from a RangePartitionBook.
Parameters
----------
@@ -169,25 +220,27 @@ def get_node_partition_from_book(book, device):
NDarrayPartition
The NDArrayPartition object for the nodes in the graph.
"""
    assert isinstance(book, RangePartitionBook), (
        "Can only convert " "RangePartitionBook to NDArrayPartition."
    )
# create prefix-sum array on host
max_node_ids = F.zerocopy_from_numpy(book._max_node_ids)
    cpu_range = F.cat(
        [F.tensor([0], dtype=F.dtype(max_node_ids)), max_node_ids + 1], dim=0
    )
gpu_range = F.copy_to(cpu_range, ctx=device)
# convert from numpy
array_size = int(F.as_scalar(cpu_range[-1]))
num_parts = book.num_partitions()
    return NDArrayPartition(
        array_size, num_parts, mode="range", part_ranges=gpu_range
    )
class GraphPartitionBook(ABC):
""" The base class of the graph partition book.
"""The base class of the graph partition book.
For distributed training, a graph is partitioned into multiple parts and is loaded
    on multiple machines. The partition book contains all necessary information to locate
@@ -368,13 +421,11 @@ class GraphPartitionBook(ABC):
@property
def ntypes(self):
"""Get the list of node types
"""
"""Get the list of node types"""
@property
def etypes(self):
"""Get the list of edge types
"""
"""Get the list of edge types"""
@property
def canonical_etypes(self):
@@ -388,9 +439,8 @@ class GraphPartitionBook(ABC):
@property
def is_homogeneous(self):
"""check if homogeneous
"""
return not(len(self.etypes) > 1 or len(self.ntypes) > 1)
"""check if homogeneous"""
return not (len(self.etypes) > 1 or len(self.ntypes) > 1)
def map_to_per_ntype(self, ids):
"""Map homogeneous node IDs to type-wise IDs and node types.
@@ -452,6 +502,7 @@ class GraphPartitionBook(ABC):
Homogeneous edge IDs.
"""
class BasicPartitionBook(GraphPartitionBook):
"""This provides the most flexible way to store parition information.
@@ -473,33 +524,40 @@ class BasicPartitionBook(GraphPartitionBook):
part_graph : DGLGraph
The graph partition structure.
"""
def __init__(self, part_id, num_parts, node_map, edge_map, part_graph):
        assert part_id >= 0, "part_id cannot be a negative number."
        assert num_parts > 0, "num_parts must be greater than zero."
self._part_id = int(part_id)
self._num_partitions = int(num_parts)
self._nid2partid = F.tensor(node_map)
        assert (
            F.dtype(self._nid2partid) == F.int64
        ), "the node map must be stored in an integer array"
self._eid2partid = F.tensor(edge_map)
        assert (
            F.dtype(self._eid2partid) == F.int64
        ), "the edge map must be stored in an integer array"
# Get meta data of the partition book.
self._partition_meta_data = []
        _, nid_count = np.unique(
            F.asnumpy(self._nid2partid), return_counts=True
        )
        _, eid_count = np.unique(
            F.asnumpy(self._eid2partid), return_counts=True
        )
for partid in range(self._num_partitions):
part_info = {}
            part_info["machine_id"] = partid
            part_info["num_nodes"] = int(nid_count[partid])
            part_info["num_edges"] = int(eid_count[partid])
self._partition_meta_data.append(part_info)
# Get partid2nids
self._partid2nids = []
sorted_nid = F.tensor(np.argsort(F.asnumpy(self._nid2partid)))
start = 0
for offset in nid_count:
            part_nids = sorted_nid[start : start + offset]
start += offset
self._partid2nids.append(part_nids)
# Get partid2eids
@@ -507,7 +565,7 @@ class BasicPartitionBook(GraphPartitionBook):
sorted_eid = F.tensor(np.argsort(F.asnumpy(self._eid2partid)))
start = 0
for offset in eid_count:
            part_eids = sorted_eid[start : start + offset]
start += offset
self._partid2eids.append(part_eids)
# Get nidg2l
@@ -515,7 +573,7 @@ class BasicPartitionBook(GraphPartitionBook):
global_id = part_graph.ndata[NID]
max_global_id = np.amax(F.asnumpy(global_id))
# TODO(chao): support int32 index
        g2l = F.zeros((max_global_id + 1), F.int64, F.context(global_id))
g2l = F.scatter_row(g2l, global_id, F.arange(0, len(global_id)))
self._nidg2l[self._part_id] = g2l
# Get eidg2l
@@ -523,7 +581,7 @@ class BasicPartitionBook(GraphPartitionBook):
global_id = part_graph.edata[EID]
max_global_id = np.amax(F.asnumpy(global_id))
# TODO(chao): support int32 index
        g2l = F.zeros((max_global_id + 1), F.int64, F.context(global_id))
g2l = F.scatter_row(g2l, global_id, F.arange(0, len(global_id)))
self._eidg2l[self._part_id] = g2l
# node size and edge size
@@ -531,33 +589,43 @@ class BasicPartitionBook(GraphPartitionBook):
self._node_size = len(self.partid2nids(self._part_id))
def shared_memory(self, graph_name):
"""Move data to shared memory.
"""
self._meta, self._nid2partid, self._eid2partid = _move_metadata_to_shared_mem(
graph_name, self._num_nodes(), self._num_edges(), self._part_id, self._num_partitions,
self._nid2partid, self._eid2partid, False)
"""Move data to shared memory."""
(
self._meta,
self._nid2partid,
self._eid2partid,
) = _move_metadata_to_shared_mem(
graph_name,
self._num_nodes(),
self._num_edges(),
self._part_id,
self._num_partitions,
self._nid2partid,
self._eid2partid,
False,
)
def num_partitions(self):
"""Return the number of partitions.
"""
"""Return the number of partitions."""
return self._num_partitions
def metadata(self):
"""Return the partition meta data.
"""
"""Return the partition meta data."""
return self._partition_meta_data
def _num_nodes(self, ntype=DEFAULT_NTYPE):
""" The total number of nodes
"""
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.'
"""The total number of nodes"""
assert (
ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return len(self._nid2partid)
def _num_edges(self, etype=DEFAULT_ETYPE):
""" The total number of edges
"""
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \
'Base partition book only supports homogeneous graph.'
"""The total number of edges"""
assert etype in (
DEFAULT_ETYPE,
DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return len(self._eid2partid)
def map_to_per_ntype(self, ids):
@@ -575,79 +643,88 @@ class BasicPartitionBook(GraphPartitionBook):
return F.zeros((len(ids),), F.int32, F.cpu()), ids
def map_to_homo_nid(self, ids, ntype=DEFAULT_NTYPE):
"""Map per-node-type IDs to global node IDs in the homogeneous format.
"""
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.'
"""Map per-node-type IDs to global node IDs in the homogeneous format."""
assert (
ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return ids
def map_to_homo_eid(self, ids, etype=DEFAULT_ETYPE):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format.
"""
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \
'Base partition book only supports homogeneous graph.'
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
assert etype in (
DEFAULT_ETYPE,
DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return ids
def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
"""From global node IDs to partition IDs
"""
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.'
"""From global node IDs to partition IDs"""
assert (
ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return F.gather_row(self._nid2partid, nids)
def eid2partid(self, eids, etype=DEFAULT_ETYPE):
"""From global edge IDs to partition IDs
"""
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \
'Base partition book only supports homogeneous graph.'
"""From global edge IDs to partition IDs"""
assert etype in (
DEFAULT_ETYPE,
DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return F.gather_row(self._eid2partid, eids)
def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition id to global node IDs
"""
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.'
"""From partition id to global node IDs"""
assert (
ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return self._partid2nids[partid]
def partid2eids(self, partid, etype=DEFAULT_ETYPE):
"""From partition id to global edge IDs
"""
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \
'Base partition book only supports homogeneous graph.'
"""From partition id to global edge IDs"""
assert etype in (
DEFAULT_ETYPE,
DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return self._partid2eids[partid]
def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE):
"""Get local node IDs within the given partition.
"""
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.'
"""Get local node IDs within the given partition."""
assert (
ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
if partid != self._part_id:
            raise RuntimeError(
                "Now GraphPartitionBook does not support \
                getting remote tensor of nid2localnid."
            )
return F.gather_row(self._nidg2l[partid], nids)
def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE):
"""Get the local edge ids within the given partition.
"""
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \
'Base partition book only supports homogeneous graph.'
"""Get the local edge ids within the given partition."""
assert etype in (
DEFAULT_ETYPE,
DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
if partid != self._part_id:
            raise RuntimeError(
                "Now GraphPartitionBook does not support \
                getting remote tensor of eid2localeid."
            )
return F.gather_row(self._eidg2l[partid], eids)
@property
def partid(self):
"""Get the current partition ID
"""
"""Get the current partition ID"""
return self._part_id
@property
def ntypes(self):
"""Get the list of node types
"""
"""Get the list of node types"""
return [DEFAULT_NTYPE]
@property
def etypes(self):
"""Get the list of edge types
"""
"""Get the list of edge types"""
return [DEFAULT_ETYPE[1]]
@property
@@ -698,9 +775,10 @@ class RangePartitionBook(GraphPartitionBook):
Single string format for keys of ``edge_map`` and ``etypes`` is deprecated.
``(str, str, str)`` will be the only format supported in the future.
"""
def __init__(self, part_id, num_parts, node_map, edge_map, ntypes, etypes):
        assert part_id >= 0, "part_id cannot be a negative number."
        assert num_parts > 0, "num_parts must be greater than zero."
self._partid = part_id
self._num_partitions = num_parts
self._ntypes = [None] * len(ntypes)
@@ -711,12 +789,14 @@ class RangePartitionBook(GraphPartitionBook):
for ntype in ntypes:
ntype_id = ntypes[ntype]
self._ntypes[ntype_id] = ntype
        assert all(
            ntype is not None for ntype in self._ntypes
        ), "The node types have invalid IDs."
for etype, etype_id in etypes.items():
if isinstance(etype, tuple):
                assert (
                    len(etype) == 3
                ), "Canonical etype should be in format of (str, str, str)."
c_etype = etype
etype = etype[1]
self._etypes[etype_id] = etype
@@ -730,11 +810,13 @@ class RangePartitionBook(GraphPartitionBook):
self._etype2canonical[etype] = c_etype
else:
dgl_warning(
"Etype with 'str' format is deprecated. Please use '(str, str, str)'.")
"Etype with 'str' format is deprecated. Please use '(str, str, str)'."
)
self._etypes[etype_id] = etype
self._canonical_etypes[etype_id] = None
        assert all(
            etype is not None for etype in self._etypes
        ), "The edge types have invalid IDs."
# This stores the node ID ranges for each node type in each partition.
# The key is the node type, the value is a NumPy matrix with two columns, in which
@@ -747,16 +829,20 @@ class RangePartitionBook(GraphPartitionBook):
self._typed_max_node_ids = {}
max_node_map = np.zeros((num_parts,), dtype=np.int64)
for key in node_map:
            assert key in ntypes, "Unexpected ntype: {}.".format(key)
if not isinstance(node_map[key], np.ndarray):
node_map[key] = F.asnumpy(node_map[key])
assert node_map[key].shape == (num_parts, 2)
self._typed_nid_range[key] = node_map[key]
# This is used for per-node-type lookup.
            self._typed_max_node_ids[key] = np.cumsum(
                self._typed_nid_range[key][:, 1]
                - self._typed_nid_range[key][:, 0]
            )
# This is used for homogeneous node ID lookup.
            max_node_map = np.maximum(
                self._typed_nid_range[key][:, 1], max_node_map
            )
# This is a vector that indicates the last node ID in each partition.
# The ID is the global ID in the homogeneous representation.
self._max_node_ids = max_node_map
@@ -767,16 +853,20 @@ class RangePartitionBook(GraphPartitionBook):
self._typed_max_edge_ids = {}
max_edge_map = np.zeros((num_parts,), dtype=np.int64)
for key in edge_map:
            assert key in etypes, "Unexpected etype: {}.".format(key)
if not isinstance(edge_map[key], np.ndarray):
edge_map[key] = F.asnumpy(edge_map[key])
assert edge_map[key].shape == (num_parts, 2)
self._typed_eid_range[key] = edge_map[key]
# This is used for per-edge-type lookup.
            self._typed_max_edge_ids[key] = np.cumsum(
                self._typed_eid_range[key][:, 1]
                - self._typed_eid_range[key][:, 0]
            )
# This is used for homogeneous edge ID lookup.
            max_edge_map = np.maximum(
                self._typed_eid_range[key][:, 1], max_edge_map
            )
# Similar to _max_node_ids
self._max_edge_ids = max_edge_map
@@ -796,14 +886,13 @@ class RangePartitionBook(GraphPartitionBook):
num_edges = erange_end - erange_start
part_info = {}
            part_info["machine_id"] = partid
            part_info["num_nodes"] = int(num_nodes)
            part_info["num_edges"] = int(num_edges)
self._partition_meta_data.append(part_info)
def shared_memory(self, graph_name):
"""Move data to shared memory.
"""
"""Move data to shared memory."""
# we need to store the nid ranges and eid ranges of different types in the order defined
# by type IDs.
nid_range = [None] * len(self.ntypes)
@@ -817,31 +906,30 @@ class RangePartitionBook(GraphPartitionBook):
eid_range[i] = (c_etype, self._typed_eid_range[c_etype])
eid_range_pickle = list(pickle.dumps(eid_range))
        self._meta = _move_metadata_to_shared_mem(
            graph_name,
            0,  # We don't need to provide the number of nodes
            0,  # We don't need to provide the number of edges
            self._partid,
            self._num_partitions,
            F.tensor(nid_range_pickle),
            F.tensor(eid_range_pickle),
            True,
        )
def num_partitions(self):
"""Return the number of partitions.
"""
"""Return the number of partitions."""
return self._num_partitions
def _num_nodes(self, ntype=DEFAULT_NTYPE):
""" The total number of nodes
"""
"""The total number of nodes"""
if ntype == DEFAULT_NTYPE:
return int(self._max_node_ids[-1])
else:
return int(self._typed_max_node_ids[ntype][-1])
def _num_edges(self, etype=DEFAULT_ETYPE):
""" The total number of edges
"""
"""The total number of edges"""
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
return int(self._max_edge_ids[-1])
else:
@@ -849,8 +937,7 @@ class RangePartitionBook(GraphPartitionBook):
return int(self._typed_max_edge_ids[c_etype][-1])
def metadata(self):
"""Return the partition meta data.
"""
"""Return the partition meta data."""
return self._partition_meta_data
def map_to_per_ntype(self, ids):
@@ -868,67 +955,75 @@ class RangePartitionBook(GraphPartitionBook):
return self._eid_map(ids)
def map_to_homo_nid(self, ids, ntype):
"""Map per-node-type IDs to global node IDs in the homogeneous format.
"""
"""Map per-node-type IDs to global node IDs in the homogeneous format."""
ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype)
typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
end_diff = F.gather_row(typed_max_nids, partids) - ids
        typed_nid_range = F.zerocopy_from_numpy(
            self._typed_nid_range[ntype][:, 1]
        )
return F.gather_row(typed_nid_range, partids) - end_diff
def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format.
"""
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
ids = utils.toindex(ids).tousertensor()
c_etype = self._to_canonical_etype(etype)
partids = self.eid2partid(ids, c_etype)
        typed_max_eids = F.zerocopy_from_numpy(
            self._typed_max_edge_ids[c_etype]
        )
end_diff = F.gather_row(typed_max_eids, partids) - ids
        typed_eid_range = F.zerocopy_from_numpy(
            self._typed_eid_range[c_etype][:, 1]
        )
return F.gather_row(typed_eid_range, partids) - end_diff
def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
"""From global node IDs to partition IDs
"""
"""From global node IDs to partition IDs"""
nids = utils.toindex(nids)
if ntype == DEFAULT_NTYPE:
            ret = np.searchsorted(
                self._max_node_ids, nids.tonumpy(), side="right"
            )
else:
            ret = np.searchsorted(
                self._typed_max_node_ids[ntype], nids.tonumpy(), side="right"
            )
ret = utils.toindex(ret)
return ret.tousertensor()
def eid2partid(self, eids, etype=DEFAULT_ETYPE):
"""From global edge IDs to partition IDs
"""
"""From global edge IDs to partition IDs"""
eids = utils.toindex(eids)
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
            ret = np.searchsorted(
                self._max_edge_ids, eids.tonumpy(), side="right"
            )
else:
c_etype = self._to_canonical_etype(etype)
            ret = np.searchsorted(
                self._typed_max_edge_ids[c_etype], eids.tonumpy(), side="right"
            )
ret = utils.toindex(ret)
return ret.tousertensor()
def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition ID to global node IDs
"""
"""From partition ID to global node IDs"""
# TODO do we need to cache it?
if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0
end = self._max_node_ids[partid]
return F.arange(start, end)
else:
            start = (
                self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0
            )
end = self._typed_max_node_ids[ntype][partid]
return F.arange(start, end)
def partid2eids(self, partid, etype=DEFAULT_ETYPE):
"""From partition ID to global edge IDs
"""
"""From partition ID to global edge IDs"""
# TODO do we need to cache it?
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0
@@ -936,33 +1031,39 @@ class RangePartitionBook(GraphPartitionBook):
return F.arange(start, end)
else:
c_etype = self._to_canonical_etype(etype)
            start = (
                self._typed_max_edge_ids[c_etype][partid - 1]
                if partid > 0
                else 0
            )
end = self._typed_max_edge_ids[c_etype][partid]
return F.arange(start, end)
def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE):
"""Get local node IDs within the given partition.
"""
"""Get local node IDs within the given partition."""
if partid != self._partid:
            raise RuntimeError(
                "Now RangePartitionBook does not support \
                getting remote tensor of nid2localnid."
            )
nids = utils.toindex(nids)
nids = nids.tousertensor()
if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0
else:
            start = (
                self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0
            )
return nids - int(start)
def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE):
"""Get the local edge IDs within the given partition.
"""
"""Get the local edge IDs within the given partition."""
if partid != self._partid:
            raise RuntimeError(
                "Now RangePartitionBook does not support \
                getting remote tensor of eid2localeid."
            )
eids = utils.toindex(eids)
eids = eids.tousertensor()
@@ -970,26 +1071,26 @@ class RangePartitionBook(GraphPartitionBook):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0
else:
c_etype = self._to_canonical_etype(etype)
            start = (
                self._typed_max_edge_ids[c_etype][partid - 1]
                if partid > 0
                else 0
            )
return eids - int(start)
@property
def partid(self):
"""Get the current partition ID.
"""
"""Get the current partition ID."""
return self._partid
@property
def ntypes(self):
"""Get the list of node types
"""
"""Get the list of node types"""
return self._ntypes
@property
def etypes(self):
"""Get the list of edge types
"""
"""Get the list of edge types"""
return self._etypes
@property
@@ -1023,12 +1124,16 @@ class RangePartitionBook(GraphPartitionBook):
if ret is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype))
if len(ret) == 0:
raise DGLError('Edge type "%s" is ambiguous. Please use canonical edge type '
'in the form of (srctype, etype, dsttype)' % etype)
raise DGLError(
'Edge type "%s" is ambiguous. Please use canonical edge type '
"in the form of (srctype, etype, dsttype)" % etype
)
return ret
NODE_PART_POLICY = "node"
EDGE_PART_POLICY = "edge"
class PartitionPolicy(object):
"""This defines a partition policy for a distributed tensor or distributed embedding.
@@ -1047,11 +1152,14 @@ class PartitionPolicy(object):
partition_book : GraphPartitionBook
A graph partition book
"""
def __init__(self, policy_str, partition_book):
        splits = policy_str.split(":")
if len(splits) == 1:
            assert policy_str in (
                EDGE_PART_POLICY,
                NODE_PART_POLICY,
            ), "policy_str must contain 'edge' or 'node'."
if NODE_PART_POLICY == policy_str:
policy_str = NODE_PART_POLICY + ":" + DEFAULT_NTYPE
else:
@@ -1106,8 +1214,7 @@ class PartitionPolicy(object):
return self._partition_book
def get_data_name(self, name):
"""Get HeteroDataName
"""
"""Get HeteroDataName"""
is_node = NODE_PART_POLICY in self.policy_str
return HeteroDataName(is_node, self.type_name, name)
@@ -1125,11 +1232,15 @@ class PartitionPolicy(object):
local ID tensor
"""
if EDGE_PART_POLICY in self.policy_str:
            return self._partition_book.eid2localeid(
                id_tensor, self._part_id, self.type_name
            )
elif NODE_PART_POLICY in self.policy_str:
            return self._partition_book.nid2localnid(
                id_tensor, self._part_id, self.type_name
            )
else:
            raise RuntimeError("Cannot support policy: %s " % self.policy_str)
def to_partid(self, id_tensor):
"""Mapping global ID to partition ID.
@@ -1149,7 +1260,7 @@ class PartitionPolicy(object):
elif NODE_PART_POLICY in self.policy_str:
return self._partition_book.nid2partid(id_tensor, self.type_name)
else:
            raise RuntimeError("Cannot support policy: %s " % self.policy_str)
def get_part_size(self):
"""Get data size of current partition.
@@ -1160,11 +1271,15 @@ class PartitionPolicy(object):
data size
"""
if EDGE_PART_POLICY in self.policy_str:
            return len(
                self._partition_book.partid2eids(self._part_id, self.type_name)
            )
elif NODE_PART_POLICY in self.policy_str:
            return len(
                self._partition_book.partid2nids(self._part_id, self.type_name)
            )
else:
            raise RuntimeError("Cannot support policy: %s " % self.policy_str)
def get_size(self):
"""Get the full size of the data.
@@ -1179,24 +1294,29 @@ class PartitionPolicy(object):
elif NODE_PART_POLICY in self.policy_str:
return self._partition_book._num_nodes(self.type_name)
else:
            raise RuntimeError("Cannot support policy: %s " % self.policy_str)
class NodePartitionPolicy(PartitionPolicy):
    """Partition policy for nodes."""
def __init__(self, partition_book, ntype=DEFAULT_NTYPE):
super(NodePartitionPolicy, self).__init__(
            NODE_PART_POLICY + ":" + ntype, partition_book
        )
class EdgePartitionPolicy(PartitionPolicy):
    """Partition policy for edges."""
def __init__(self, partition_book, etype=DEFAULT_ETYPE):
super(EdgePartitionPolicy, self).__init__(
            EDGE_PART_POLICY + ":" + str(etype), partition_book
        )
class HeteroDataName(object):
    """The data name in a heterogeneous graph.
A unique data name has three components:
    * whether it is node data or edge data.
@@ -1211,7 +1331,8 @@ class HeteroDataName(object):
The type of the node/edge.
data_name : str
The name of the data.
    """
def __init__(self, is_node, entity_type, data_name):
self._policy = NODE_PART_POLICY if is_node else EDGE_PART_POLICY
self._entity_type = entity_type
@@ -1219,41 +1340,38 @@ class HeteroDataName(object):
@property
def policy_str(self):
        """Concatenate the policy and entity type into a string."""
        return self._policy + ":" + str(self.get_type())
def is_node(self):
        """Is this the name of node data"""
return NODE_PART_POLICY in self.policy_str
def is_edge(self):
        """Is this the name of edge data"""
return EDGE_PART_POLICY in self.policy_str
def get_type(self):
        """The type of the node/edge.
This is only meaningful in a heterogeneous graph.
        In a homogeneous graph, the type is '_N' for a node and '_E' for an edge.
        """
return self._entity_type
def get_name(self):
        """The name of the data."""
return self.data_name
def __str__(self):
        """The full name of the data.

        The full name is used as the key in the KVStore.
        """
        return self.policy_str + ":" + self.data_name
def parse_hetero_data_name(name):
    """Parse data name and create HeteroDataName.
The data name has a specialized format. We can parse the name to determine if
it's node data or edge data, node/edge type and its actual name. The data name
@@ -1267,9 +1385,15 @@ def parse_hetero_data_name(name):
Returns
-------
HeteroDataName
    """
    names = name.split(":")
    assert len(names) == 3, "{} is not a valid heterograph data name".format(
        name
    )
    assert names[0] in (
        NODE_PART_POLICY,
        EDGE_PART_POLICY,
    ), "{} is not a valid heterograph data name".format(name)
    return HeteroDataName(
        names[0] == NODE_PART_POLICY, _str_to_tuple(names[1]), names[2]
    )
"""A set of graph services of getting subgraphs from DistGraph"""
from collections import namedtuple

import numpy as np

from .. import backend as F
from ..base import EID, NID
from ..convert import graph, heterograph
from ..sampling import sample_etype_neighbors as local_sample_etype_neighbors
from ..sampling import sample_neighbors as local_sample_neighbors
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
    Request,
    Response,
    recv_responses,
    register_service,
    send_requests_to_machine,
)
__all__ = [
    "sample_neighbors",
    "sample_etype_neighbors",
    "in_subgraph",
    "find_edges",
]
SAMPLING_SERVICE_ID = 6657
@@ -24,6 +32,7 @@ OUTDEGREE_SERVICE_ID = 6660
INDEGREE_SERVICE_ID = 6661
ETYPE_SAMPLING_SERVICE_ID = 6662
class SubgraphResponse(Response):
"""The response for sampling and in_subgraph"""
@@ -38,6 +47,7 @@ class SubgraphResponse(Response):
def __getstate__(self):
return self.global_src, self.global_dst, self.global_eids
class FindEdgeResponse(Response):
"""The response for sampling and in_subgraph"""
@@ -52,8 +62,11 @@ class FindEdgeResponse(Response):
def __getstate__(self):
return self.global_src, self.global_dst, self.order_id
def _sample_neighbors(
    local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
):
    """Sample from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    perform sampling and map the sampled results to the global ID space again.
@@ -64,17 +77,35 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
local_ids = F.astype(local_ids, local_g.idtype)
# local_ids = self.seed_nodes
sampled_graph = local_sample_neighbors(
        local_g,
        local_ids,
        fan_out,
        edge_dir,
        prob,
        replace,
        _dist_training=True,
    )
global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges()
    global_src, global_dst = F.gather_row(
        global_nid_mapping, src
    ), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
def _sample_etype_neighbors(
    local_g,
    partition_book,
    seed_nodes,
    etype_field,
    fan_out,
    edge_dir,
    prob,
    replace,
    etype_sorted=False,
):
    """Sample from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    perform sampling and map the sampled results to the global ID space again.
@@ -85,18 +116,28 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
local_ids = F.astype(local_ids, local_g.idtype)
sampled_graph = local_sample_etype_neighbors(
        local_g,
        local_ids,
        etype_field,
        fan_out,
        edge_dir,
        prob,
        replace,
        etype_sorted=etype_sorted,
        _dist_training=True,
    )
global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges()
    global_src, global_dst = F.gather_row(
        global_nid_mapping, src
    ), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
def _find_edges(local_g, partition_book, seed_edges):
"""Given an edge ID array, return the source
    and destination node ID array ``s`` and ``d`` in the local partition.
"""
local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)
local_eids = F.astype(local_eids, local_g.idtype)
@@ -106,22 +147,23 @@ def _find_edges(local_g, partition_book, seed_edges):
global_dst = global_nid_mapping[local_dst]
return global_src, global_dst
def _in_degrees(local_g, partition_book, n):
"""Get in-degree of the nodes in the local partition.
"""
"""Get in-degree of the nodes in the local partition."""
local_nids = partition_book.nid2localnid(n, partition_book.partid)
local_nids = F.astype(local_nids, local_g.idtype)
return local_g.in_degrees(local_nids)
def _out_degrees(local_g, partition_book, n):
"""Get out-degree of the nodes in the local partition.
"""
"""Get out-degree of the nodes in the local partition."""
local_nids = partition_book.nid2localnid(n, partition_book.partid)
local_nids = F.astype(local_nids, local_g.idtype)
return local_g.out_degrees(local_nids)
def _in_subgraph(local_g, partition_book, seed_nodes):
""" Get in subgraph from local partition.
"""Get in subgraph from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    get in-subgraph and map the results to the global ID space again.
@@ -142,7 +184,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
class SamplingRequest(Request):
"""Sampling Request"""
    def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False):
self.seed_nodes = nodes
self.edge_dir = edge_dir
self.prob = prob
@@ -150,25 +192,51 @@ class SamplingRequest(Request):
self.fan_out = fan_out
def __setstate__(self, state):
        (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
        ) = state
def __getstate__(self):
        return (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
        )
def process_request(self, server_state):
local_g = server_state.graph
partition_book = server_state.partition_book
        global_src, global_dst, global_eids = _sample_neighbors(
            local_g,
            partition_book,
            self.seed_nodes,
            self.fan_out,
            self.edge_dir,
            self.prob,
            self.replace,
        )
return SubgraphResponse(global_src, global_dst, global_eids)
class SamplingRequestEtype(Request):
"""Sampling Request"""
    def __init__(
        self,
        nodes,
        etype_field,
        fan_out,
        edge_dir="in",
        prob=None,
        replace=False,
        etype_sorted=True,
    ):
self.seed_nodes = nodes
self.edge_dir = edge_dir
self.prob = prob
@@ -178,27 +246,44 @@ class SamplingRequestEtype(Request):
self.etype_sorted = etype_sorted
def __setstate__(self, state):
        (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
            self.etype_field,
            self.etype_sorted,
        ) = state
def __getstate__(self):
        return (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
            self.etype_field,
            self.etype_sorted,
        )
def process_request(self, server_state):
local_g = server_state.graph
partition_book = server_state.partition_book
global_src, global_dst, global_eids = _sample_etype_neighbors(local_g,
partition_book,
self.seed_nodes,
self.etype_field,
self.fan_out,
self.edge_dir,
self.prob,
self.replace,
self.etype_sorted)
global_src, global_dst, global_eids = _sample_etype_neighbors(
local_g,
partition_book,
self.seed_nodes,
self.etype_field,
self.fan_out,
self.edge_dir,
self.prob,
self.replace,
self.etype_sorted,
)
return SubgraphResponse(global_src, global_dst, global_eids)
class EdgesRequest(Request):
"""Edges Request"""
......@@ -215,10 +300,13 @@ class EdgesRequest(Request):
def process_request(self, server_state):
local_g = server_state.graph
partition_book = server_state.partition_book
global_src, global_dst = _find_edges(local_g, partition_book, self.edge_ids)
global_src, global_dst = _find_edges(
local_g, partition_book, self.edge_ids
)
return FindEdgeResponse(global_src, global_dst, self.order_id)
class InDegreeRequest(Request):
"""In-degree Request"""
......@@ -239,6 +327,7 @@ class InDegreeRequest(Request):
return InDegreeResponse(deg, self.order_id)
class InDegreeResponse(Response):
"""The response for in-degree"""
......@@ -252,6 +341,7 @@ class InDegreeResponse(Response):
def __getstate__(self):
return self.val, self.order_id
class OutDegreeRequest(Request):
"""Out-degree Request"""
......@@ -272,6 +362,7 @@ class OutDegreeRequest(Request):
return OutDegreeResponse(deg, self.order_id)
class OutDegreeResponse(Response):
"""The response for out-degree"""
......@@ -285,6 +376,7 @@ class OutDegreeResponse(Response):
def __getstate__(self):
return self.val, self.order_id
class InSubgraphRequest(Request):
"""InSubgraph Request"""
......@@ -300,8 +392,9 @@ class InSubgraphRequest(Request):
def process_request(self, server_state):
local_g = server_state.graph
partition_book = server_state.partition_book
global_src, global_dst, global_eids = _in_subgraph(local_g, partition_book,
self.seed_nodes)
global_src, global_dst, global_eids = _in_subgraph(
local_g, partition_book, self.seed_nodes
)
return SubgraphResponse(global_src, global_dst, global_eids)
......@@ -326,10 +419,14 @@ def merge_graphs(res_list, num_nodes):
g.edata[EID] = eid_tensor
return g
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
LocalSampledGraph = namedtuple(
"LocalSampledGraph", "global_src global_dst global_eids"
)
def _distributed_access(g, nodes, issue_remote_req, local_access):
'''A routine that fetches the local neighborhoods of nodes from the distributed graph.
"""A routine that fetches the local neighborhoods of nodes from the distributed graph.
The local neighborhoods of some nodes are stored on the local machine, while other
nodes have their neighborhoods on remote machines. This code will issue remote
......@@ -352,7 +449,7 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
-------
DGLHeteroGraph
The subgraph that contains the neighborhoods of all input nodes.
'''
"""
req_list = []
partition_book = g.get_partition_book()
nodes = toindex(nodes).tousertensor()
......@@ -379,7 +476,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
src, dst, eids = local_access(g.local_partition, partition_book, local_nids)
src, dst, eids = local_access(
g.local_partition, partition_book, local_nids
)
res_list.append(LocalSampledGraph(src, dst, eids))
# receive responses from remote machines.
......@@ -390,13 +489,18 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
sampled_graph = merge_graphs(res_list, g.number_of_nodes())
return sampled_graph
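# Editor's sketch (not part of this commit): the callback contract that
# _distributed_access expects. The fanout value is an assumption for
# illustration; sample_neighbors below wires these callbacks up for real.
def _example_callbacks(fanout=10):
    def issue_remote_req(node_ids):
        # Built once per remote partition that owns some of the input nodes.
        return SamplingRequest(node_ids, fanout, edge_dir="in")

    def local_access(local_g, partition_book, local_nids):
        # Serves the same query directly against the local partition.
        return _sample_neighbors(
            local_g, partition_book, local_nids, fanout, "in", None, False
        )

    return issue_remote_req, local_access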
def _frontier_to_heterogeneous_graph(g, frontier, gpb):
# We need to handle empty frontiers correctly.
if frontier.number_of_edges() == 0:
data_dict = {etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes}
return heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)
data_dict = {
etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes
}
return heterograph(
data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype,
)
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
......@@ -413,19 +517,32 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
data_dict[canonical_etype] = (
F.boolean_mask(src, type_idx),
F.boolean_mask(dst, type_idx),
)
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)
hg = heterograph(
data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype,
)
for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
def sample_etype_neighbors(
g,
nodes,
etype_field,
fanout,
edge_dir="in",
prob=None,
replace=False,
etype_sorted=True,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
......@@ -495,28 +612,50 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
if isinstance(nodes, dict):
homo_nids = []
for ntype in nodes.keys():
assert ntype in g.ntypes, \
'The sampled node type {} does not exist in the input graph'.format(ntype)
assert (
ntype in g.ntypes
), "The sampled node type {} does not exist in the input graph".format(
ntype
)
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids):
return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir,
prob=prob, replace=replace, etype_sorted=etype_sorted)
return SamplingRequestEtype(
node_ids,
etype_field,
fanout,
edge_dir=edge_dir,
prob=prob,
replace=replace,
etype_sorted=etype_sorted,
)
def local_access(local_g, partition_book, local_nids):
return _sample_etype_neighbors(local_g, partition_book, local_nids,
etype_field, fanout, edge_dir, prob, replace,
etype_sorted=etype_sorted)
return _sample_etype_neighbors(
local_g,
partition_book,
local_nids,
etype_field,
fanout,
edge_dir,
prob,
replace,
etype_sorted=etype_sorted,
)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
return frontier
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
......@@ -570,7 +709,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
assert isinstance(nodes, dict)
homo_nids = []
for ntype in nodes:
assert ntype in g.ntypes, 'The sampled node type does not exist in the input graph'
assert (
ntype in g.ntypes
), "The sampled node type does not exist in the input graph"
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
......@@ -582,17 +723,22 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
nodes = list(nodes.values())[0]
def issue_remote_req(node_ids):
return SamplingRequest(node_ids, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
return SamplingRequest(
node_ids, fanout, edge_dir=edge_dir, prob=prob, replace=replace
)
def local_access(local_g, partition_book, local_nids):
return _sample_neighbors(local_g, partition_book, local_nids,
fanout, edge_dir, prob, replace)
return _sample_neighbors(
local_g, partition_book, local_nids, fanout, edge_dir, prob, replace
)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
return frontier
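# Editor's sketch (not part of this commit): trainer-side use on an
# initialized DistGraph. "ip_config.txt" and "my_graph" are placeholder names
# and assume a matching server/partition setup is already running.
def _example_sample_neighbors():
    import dgl

    dgl.distributed.initialize("ip_config.txt")
    g = dgl.distributed.DistGraph("my_graph")
    # One round of fanout-10 sampling; the returned frontier carries the
    # global edge IDs in its dgl.EID edge feature.
    frontier = sample_neighbors(g, [0, 1, 2], 10, edge_dir="in")
    return frontier.edges()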
def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""A routine that fetches local edges from distributed graph.
......@@ -626,7 +772,7 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
local_eids = None
reorder_idx = []
for pid in range(partition_book.num_partitions()):
mask = (partition_id == pid)
mask = partition_id == pid
edge_id = F.boolean_mask(edges, mask)
reorder_idx.append(F.nonzero_1d(mask))
if pid == partition_book.partid and g.local_partition is not None:
......@@ -646,8 +792,12 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
dst_ids = F.zeros_like(edges)
if local_eids is not None:
src, dst = local_access(g.local_partition, partition_book, local_eids)
src_ids = F.scatter_row(src_ids, reorder_idx[partition_book.partid], src)
dst_ids = F.scatter_row(dst_ids, reorder_idx[partition_book.partid], dst)
src_ids = F.scatter_row(
src_ids, reorder_idx[partition_book.partid], src
)
dst_ids = F.scatter_row(
dst_ids, reorder_idx[partition_book.partid], dst
)
# receive responses from remote machines.
if msgseq2pos is not None:
......@@ -659,8 +809,9 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst)
return src_ids, dst_ids
def find_edges(g, edge_ids):
""" Given an edge ID array, return the source and destination
"""Given an edge ID array, return the source and destination
node ID array ``s`` and ``d`` from a distributed graph.
``s[i]`` and ``d[i]`` are the source and destination node IDs for
edge ``eid[i]``.
......@@ -679,12 +830,16 @@ def find_edges(g, edge_ids):
tensor
The destination node ID array.
"""
def issue_remote_req(edge_ids, order_id):
return EdgesRequest(edge_ids, order_id)
def local_access(local_g, partition_book, edge_ids):
return _find_edges(local_g, partition_book, edge_ids)
return _distributed_edge_access(g, edge_ids, issue_remote_req, local_access)
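# Editor's sketch (not part of this commit): resolving global edge IDs to
# their endpoints across partitions. Assumes `g` is an initialized DistGraph
# as in the sketch above; the edge IDs are illustrative.
def _example_find_edges(g):
    eids = F.tensor([0, 5, 42])
    src, dst = find_edges(g, eids)  # src[i]/dst[i] belong to edge eids[i]
    return src, dst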
def in_subgraph(g, nodes):
"""Return the subgraph induced on the inbound edges of the given nodes.
......@@ -713,14 +868,20 @@ def in_subgraph(g, nodes):
edge ID via ``dgl.EID`` edge features of the subgraph.
"""
if isinstance(nodes, dict):
assert len(nodes) == 1, 'The distributed in_subgraph only supports one node type for now.'
assert (
len(nodes) == 1
), "The distributed in_subgraph only supports one node type for now."
nodes = list(nodes.values())[0]
def issue_remote_req(node_ids):
return InSubgraphRequest(node_ids)
def local_access(local_g, partition_book, local_nids):
return _in_subgraph(local_g, partition_book, local_nids)
return _distributed_access(g, nodes, issue_remote_req, local_access)
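# Editor's sketch (not part of this commit): unlike sample_neighbors, this
# keeps every inbound edge of the seeds, which suits full-neighborhood
# aggregation. Assumes `g` is an initialized DistGraph.
def _example_in_subgraph(g):
    sg = in_subgraph(g, [0, 1, 2])
    return sg.edata[EID]  # global IDs of the induced edges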
def _distributed_get_node_property(g, n, issue_remote_req, local_access):
req_list = []
partition_book = g.get_partition_book()
......@@ -729,7 +890,7 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
local_nids = None
reorder_idx = []
for pid in range(partition_book.num_partitions()):
mask = (partition_id == pid)
mask = partition_id == pid
nid = F.boolean_mask(n, mask)
reorder_idx.append(F.nonzero_1d(mask))
if pid == partition_book.partid and g.local_partition is not None:
......@@ -751,7 +912,9 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
shape = list(F.shape(local_vals))
shape[0] = len(n)
vals = F.zeros(shape, F.dtype(local_vals), F.cpu())
vals = F.scatter_row(vals, reorder_idx[partition_book.partid], local_vals)
vals = F.scatter_row(
vals, reorder_idx[partition_book.partid], local_vals
)
# receive responses from remote machines.
if msgseq2pos is not None:
......@@ -765,27 +928,36 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
vals = F.scatter_row(vals, reorder_idx[result.order_id], val)
return vals
def in_degrees(g, v):
'''Get in-degrees
'''
"""Get in-degrees"""
def issue_remote_req(v, order_id):
return InDegreeRequest(v, order_id)
def local_access(local_g, partition_book, v):
return _in_degrees(local_g, partition_book, v)
return _distributed_get_node_property(g, v, issue_remote_req, local_access)
def out_degrees(g, u):
'''Get out-degrees
'''
"""Get out-degrees"""
def issue_remote_req(u, order_id):
return OutDegreeRequest(u, order_id)
def local_access(local_g, partition_book, u):
return _out_degrees(local_g, partition_book, u)
return _distributed_get_node_property(g, u, issue_remote_req, local_access)
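# Editor's sketch (not part of this commit): degree queries follow the same
# scatter/gather pattern, and the results line up with the input node order.
# Assumes `g` is an initialized DistGraph.
def _example_degrees(g):
    nodes = F.tensor([0, 1, 2])
    return in_degrees(g, nodes), out_degrees(g, nodes)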
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(EDGES_SERVICE_ID, EdgesRequest, FindEdgeResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
register_service(OUTDEGREE_SERVICE_ID, OutDegreeRequest, OutDegreeResponse)
register_service(INDEGREE_SERVICE_ID, InDegreeRequest, InDegreeResponse)
register_service(ETYPE_SAMPLING_SERVICE_ID, SamplingRequestEtype, SubgraphResponse)
register_service(
ETYPE_SAMPLING_SERVICE_ID, SamplingRequestEtype, SubgraphResponse
)
"""dgl distributed.optims."""
import importlib
import sys
import os
import sys
from ...backend import backend_name
from ...utils import expand_as_pair
def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
_load_backend(backend_name)
"""dgl distributed.optims."""
import importlib
import sys
import os
import sys
from ...backend import backend_name
from ...utils import expand_as_pair
def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
_load_backend(backend_name)
......@@ -3,6 +3,7 @@
import torch as th
import torch.distributed as dist
def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
"""Each process scatters list of input tensors to all processes in a cluster
and return gathered list of tensors in output list. The tensors should have the same shape.
......@@ -18,9 +19,14 @@ def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
input_tensor_list : List of tensor
The tensors to exchange
"""
input_tensor_list = [tensor.to(th.device('cpu')) for tensor in input_tensor_list]
input_tensor_list = [
tensor.to(th.device("cpu")) for tensor in input_tensor_list
]
for i in range(world_size):
dist.scatter(output_tensor_list[i], input_tensor_list if i == rank else [], src=i)
dist.scatter(
output_tensor_list[i], input_tensor_list if i == rank else [], src=i
)
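# Editor's sketch (not part of this commit): each rank contributes one tensor
# per peer; after the exchange, outputs[i] holds what rank i addressed to this
# rank. Assumes torch.distributed was initialized with the "gloo" backend and
# that this runs on every rank (e.g. launched via torchrun).
def _example_alltoall(rank, world_size):
    inputs = [th.full((2,), rank, dtype=th.int64) for _ in range(world_size)]
    outputs = [th.empty(2, dtype=th.int64) for _ in range(world_size)]
    alltoall_cpu(rank, world_size, outputs, inputs)
    return outputs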
def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
"""Each process scatters list of input tensors to all processes in a cluster
......@@ -42,9 +48,11 @@ def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
senders = []
for i in range(world_size):
if i == rank:
output_tensor_list[i] = input_tensor_list[i].to(th.device('cpu'))
output_tensor_list[i] = input_tensor_list[i].to(th.device("cpu"))
else:
sender = dist.isend(input_tensor_list[i].to(th.device('cpu')), dst=i)
sender = dist.isend(
input_tensor_list[i].to(th.device("cpu")), dst=i
)
senders.append(sender)
for i in range(world_size):
......
......@@ -5,6 +5,7 @@ some work as trainers.
"""
import os
import numpy as np
from . import rpc
......@@ -12,10 +13,12 @@ from . import rpc
REGISTER_ROLE = 700001
REG_ROLE_MSG = "Register_Role"
class RegisterRoleResponse(rpc.Response):
"""Send a confirmation signal (just a short string message)
of RegisterRoleRequest to client.
"""
def __init__(self, msg):
self.msg = msg
......@@ -25,6 +28,7 @@ class RegisterRoleResponse(rpc.Response):
def __setstate__(self, state):
self.msg = state
class RegisterRoleRequest(rpc.Request):
"""Send client id and role to server
......@@ -35,6 +39,7 @@ class RegisterRoleRequest(rpc.Request):
role : str
role of client
"""
def __init__(self, client_id, machine_id, role):
self.client_id = client_id
self.machine_id = machine_id
......@@ -53,7 +58,9 @@ class RegisterRoleRequest(rpc.Request):
if self.role not in role:
role[self.role] = set()
if kv_store is not None:
barrier_count = kv_store.barrier_count.setdefault(self.group_id, {})
barrier_count = kv_store.barrier_count.setdefault(
self.group_id, {}
)
barrier_count[self.role] = 0
role[self.role].add((self.client_id, self.machine_id))
total_count = 0
......@@ -67,11 +74,14 @@ class RegisterRoleRequest(rpc.Request):
return res_list
return None
GET_ROLE = 700002
GET_ROLE_MSG = "Get_Role"
class GetRoleResponse(rpc.Response):
"""Send the roles of all client processes"""
def __init__(self, role):
self.role = role
self.msg = GET_ROLE_MSG
......@@ -82,8 +92,10 @@ class GetRoleResponse(rpc.Response):
def __setstate__(self, state):
self.role, self.msg = state
class GetRoleRequest(rpc.Request):
"""Send a request to get the roles of all client processes."""
def __init__(self):
self.msg = GET_ROLE_MSG
self.group_id = rpc.get_group_id()
......@@ -97,6 +109,7 @@ class GetRoleRequest(rpc.Request):
def process_request(self, server_state):
return GetRoleResponse(server_state.roles[self.group_id])
# The key is role, the value is a dict mapping RPC rank to a rank within the role.
PER_ROLE_RANK = {}
......@@ -109,6 +122,7 @@ CUR_ROLE = None
IS_STANDALONE = False
def init_role(role):
"""Initialize the role of the current process.
......@@ -128,10 +142,10 @@ def init_role(role):
global GLOBAL_RANK
global IS_STANDALONE
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
if role == 'default':
if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
if role == "default":
GLOBAL_RANK[0] = 0
PER_ROLE_RANK['default'] = {0:0}
PER_ROLE_RANK["default"] = {0: 0}
IS_STANDALONE = True
return
......@@ -160,7 +174,7 @@ def init_role(role):
global_rank = 0
# We want to ensure that the global rank of the trainer process starts from 0.
role_names = ['default']
role_names = ["default"]
for role_name in response.role:
if role_name not in role_names:
role_names.append(role_name)
......@@ -185,6 +199,7 @@ def init_role(role):
PER_ROLE_RANK[role_name][client_id] = per_role_rank
per_role_rank += 1
def get_global_rank():
"""Get the global rank
......@@ -196,6 +211,7 @@ def get_global_rank():
else:
return GLOBAL_RANK[rpc.get_rank()]
def get_rank(role):
"""Get the role-specific rank"""
if IS_STANDALONE:
......@@ -203,25 +219,29 @@ def get_rank(role):
else:
return PER_ROLE_RANK[role][rpc.get_rank()]
def get_trainer_rank():
"""Get the rank of the current trainer process.
This function can only be called in the trainer process. It will result in
an error if it's called in a process with any other role.
"""
assert CUR_ROLE == 'default'
assert CUR_ROLE == "default"
if IS_STANDALONE:
return 0
else:
return PER_ROLE_RANK['default'][rpc.get_rank()]
return PER_ROLE_RANK["default"][rpc.get_rank()]
def get_role():
"""Get the role of the current process"""
return CUR_ROLE
def get_num_trainers():
"""Get the number of trainer processes"""
return len(PER_ROLE_RANK['default'])
return len(PER_ROLE_RANK["default"])
rpc.register_service(REGISTER_ROLE, RegisterRoleRequest, RegisterRoleResponse)
rpc.register_service(GET_ROLE, GetRoleRequest, GetRoleResponse)
"""RPC components. They are typically functions or utilities used by both
server and clients."""
import os
import abc
import os
import pickle
import random
import numpy as np
from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
import numpy as np
from .._ffi.object import register_object, ObjectBase
from .. import backend as F
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..base import DGLError
from .. import backend as F
from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', 'DistConnectError', \
'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory']
__all__ = [
"set_rank",
"get_rank",
"Request",
"Response",
"register_service",
"create_sender",
"create_receiver",
"finalize_sender",
"finalize_receiver",
"wait_for_senders",
"connect_receiver",
"read_ip_config",
"get_group_id",
"get_num_machines",
"set_num_machines",
"get_machine_id",
"set_machine_id",
"send_request",
"recv_request",
"send_response",
"recv_response",
"remote_call",
"send_request_to_machine",
"remote_call_to_machine",
"fast_pull",
"DistConnectError",
"get_num_client",
"set_num_client",
"client_barrier",
"copy_data_to_shared_memory",
]
REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {}
......@@ -27,6 +52,7 @@ SERVICE_ID_TO_PROPERTY = {}
DEFUALT_PORT = 30050
def read_ip_config(filename, num_servers):
"""Read network configuration information of server from file.
......@@ -75,13 +101,15 @@ def read_ip_config(filename, num_servers):
6:[3, '172.31.30.180', 30050, 2],
7:[3, '172.31.30.180', 30051, 2]}
"""
assert len(filename) > 0, 'filename cannot be empty.'
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert len(filename) > 0, "filename cannot be empty."
assert num_servers > 0, (
"num_servers (%d) must be a positive number." % num_servers
)
server_namebook = {}
try:
server_id = 0
machine_id = 0
lines = [line.rstrip('\n') for line in open(filename)]
lines = [line.rstrip("\n") for line in open(filename)]
for line in lines:
result = line.split()
if len(result) == 2:
......@@ -89,21 +117,27 @@ def read_ip_config(filename, num_servers):
elif len(result) == 1:
port = DEFUALT_PORT
else:
raise RuntimeError('length of result can only be 1 or 2.')
raise RuntimeError("length of result can only be 1 or 2.")
ip_addr = result[0]
for s_count in range(num_servers):
server_namebook[server_id] = [machine_id, ip_addr, port+s_count, num_servers]
server_namebook[server_id] = [
machine_id,
ip_addr,
port + s_count,
num_servers,
]
server_id += 1
machine_id += 1
except RuntimeError:
print("Error: data format on each line should be: [ip] [port]")
return server_namebook
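# Editor's sketch (not part of this commit): a two-machine ip_config and the
# namebook it yields with num_servers=2, mirroring the docstring above.
# Addresses are placeholders; when a port is omitted, DEFUALT_PORT (30050) is
# used, and each extra server process on a machine gets the next port.
#
#   ip_config.txt:
#       172.31.40.143 30050
#       172.31.36.140 30050
#
#   read_ip_config("ip_config.txt", num_servers=2)
#   # => {0: [0, '172.31.40.143', 30050, 2], 1: [0, '172.31.40.143', 30051, 2],
#   #     2: [1, '172.31.36.140', 30050, 2], 3: [1, '172.31.36.140', 30051, 2]}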
def reset():
"""Reset the rpc context
"""
"""Reset the rpc context"""
_CAPI_DGLRPCReset()
def create_sender(max_queue_size, net_type):
"""Create rpc sender of this process.
......@@ -114,9 +148,10 @@ def create_sender(max_queue_size, net_type):
net_type : str
Networking type. Current options are: 'socket', 'tensorpipe'.
"""
max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, max_thread_count)
def create_receiver(max_queue_size, net_type):
"""Create rpc receiver of this process.
......@@ -127,19 +162,20 @@ def create_receiver(max_queue_size, net_type):
net_type : str
Networking type. Current options are: 'socket', 'tensorpipe'.
"""
max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, max_thread_count)
def finalize_sender():
"""Finalize rpc sender of this process.
"""
"""Finalize rpc sender of this process."""
_CAPI_DGLRPCFinalizeSender()
def finalize_receiver():
"""Finalize rpc receiver of this process.
"""
"""Finalize rpc receiver of this process."""
_CAPI_DGLRPCFinalizeReceiver()
def wait_for_senders(ip_addr, port, num_senders, blocking=True):
"""Wait all of the senders' connections.
......@@ -158,6 +194,7 @@ def wait_for_senders(ip_addr, port, num_senders, blocking=True):
"""
_CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders), blocking)
def connect_receiver(ip_addr, port, recv_id, group_id=-1):
"""Connect to target receiver
......@@ -170,11 +207,14 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1):
recv_id : int
receiver's ID
"""
target_id = recv_id if group_id == -1 else register_client(recv_id, group_id)
target_id = (
recv_id if group_id == -1 else register_client(recv_id, group_id)
)
if target_id < 0:
raise DGLError("Invalid target id: {}".format(target_id))
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
def connect_receiver_finalize(max_try_times):
"""Finalize the action to connect to receivers. Make sure that either all connections are
successfully established or the connection fails.
......@@ -189,6 +229,7 @@ def connect_receiver_finalize(max_try_times):
"""
return _CAPI_DGLRPCConnectReceiverFinalize(max_try_times)
def set_rank(rank):
"""Set the rank of this process.
......@@ -202,6 +243,7 @@ def set_rank(rank):
"""
_CAPI_DGLRPCSetRank(int(rank))
def get_rank():
"""Get the rank of this process.
......@@ -215,6 +257,7 @@ def get_rank():
"""
return _CAPI_DGLRPCGetRank()
def set_machine_id(machine_id):
"""Set current machine ID
......@@ -225,6 +268,7 @@ def set_machine_id(machine_id):
"""
_CAPI_DGLRPCSetMachineID(int(machine_id))
def get_machine_id():
"""Get current machine ID
......@@ -235,6 +279,7 @@ def get_machine_id():
"""
return _CAPI_DGLRPCGetMachineID()
def set_num_machines(num_machines):
"""Set number of machine
......@@ -245,6 +290,7 @@ def set_num_machines(num_machines):
"""
_CAPI_DGLRPCSetNumMachines(int(num_machines))
def get_num_machines():
"""Get number of machines
......@@ -255,36 +301,37 @@ def get_num_machines():
"""
return _CAPI_DGLRPCGetNumMachines()
def set_num_server(num_server):
"""Set the total number of server.
"""
"""Set the total number of server."""
_CAPI_DGLRPCSetNumServer(int(num_server))
def get_num_server():
"""Get the total number of server.
"""
"""Get the total number of server."""
return _CAPI_DGLRPCGetNumServer()
def set_num_client(num_client):
"""Set the total number of client.
"""
"""Set the total number of client."""
_CAPI_DGLRPCSetNumClient(int(num_client))
def get_num_client():
"""Get the total number of client.
"""
"""Get the total number of client."""
return _CAPI_DGLRPCGetNumClient()
def set_num_server_per_machine(num_server):
"""Set the total number of server per machine
"""
"""Set the total number of server per machine"""
_CAPI_DGLRPCSetNumServerPerMachine(num_server)
def get_num_server_per_machine():
"""Get the total number of server per machine
"""
"""Get the total number of server per machine"""
return _CAPI_DGLRPCGetNumServerPerMachine()
def incr_msg_seq():
"""Increment the message sequence number and return the old one.
......@@ -295,6 +342,7 @@ def incr_msg_seq():
"""
return _CAPI_DGLRPCIncrMsgSeq()
def get_msg_seq():
"""Get the current message sequence number.
......@@ -305,6 +353,7 @@ def get_msg_seq():
"""
return _CAPI_DGLRPCGetMsgSeq()
def set_msg_seq(msg_seq):
"""Set the current message sequence number.
......@@ -315,6 +364,7 @@ def set_msg_seq(msg_seq):
"""
_CAPI_DGLRPCSetMsgSeq(int(msg_seq))
def register_service(service_id, req_cls, res_cls=None):
"""Register a service to RPC.
......@@ -332,6 +382,7 @@ def register_service(service_id, req_cls, res_cls=None):
RESPONSE_CLASS_TO_SERVICE_ID[res_cls] = service_id
SERVICE_ID_TO_PROPERTY[service_id] = (req_cls, res_cls)
def get_service_property(service_id):
"""Get service property.
......@@ -347,6 +398,7 @@ def get_service_property(service_id):
"""
return SERVICE_ID_TO_PROPERTY[service_id]
class Request:
"""Base request class"""
......@@ -389,9 +441,14 @@ class Request:
cls = self.__class__
sid = REQUEST_CLASS_TO_SERVICE_ID.get(cls, None)
if sid is None:
raise DGLError('Request class {} has not been registered as a service.'.format(cls))
raise DGLError(
"Request class {} has not been registered as a service.".format(
cls
)
)
return sid
class Response:
"""Base response class"""
......@@ -417,9 +474,14 @@ class Response:
cls = self.__class__
sid = RESPONSE_CLASS_TO_SERVICE_ID.get(cls, None)
if sid is None:
raise DGLError('Response class {} has not been registered as a service.'.format(cls))
raise DGLError(
"Response class {} has not been registered as a service.".format(
cls
)
)
return sid
def serialize_to_payload(serializable):
"""Serialize an object to payloads.
......@@ -452,11 +514,14 @@ def serialize_to_payload(serializable):
data = bytearray(pickle.dumps((nonarray_pos, nonarray_state)))
return data, array_state
class PlaceHolder:
"""PlaceHolder object for deserialization"""
_PLACEHOLDER = PlaceHolder()
def deserialize_from_payload(cls, data, tensors):
"""Deserialize and reconstruct the object from payload.
......@@ -496,7 +561,8 @@ def deserialize_from_payload(cls, data, tensors):
obj.__setstate__(state)
return obj
@register_object('rpc.RPCMessage')
@register_object("rpc.RPCMessage")
class RPCMessage(ObjectBase):
"""Serialized RPC message that can be sent to remote processes.
......@@ -519,7 +585,17 @@ class RPCMessage(ObjectBase):
group_id : int
The group ID
"""
def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors, group_id=0):
def __init__(
self,
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
group_id=0,
):
self.__init_handle_by_constructor__(
_CAPI_DGLRPCCreateRPCMessage,
int(service_id),
......@@ -528,7 +604,8 @@ class RPCMessage(ObjectBase):
int(server_id),
data,
[F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors],
int(group_id))
int(group_id),
)
@property
def service_id(self):
......@@ -566,6 +643,7 @@ class RPCMessage(ObjectBase):
"""Get group ID."""
return _CAPI_DGLRPCMessageGetGroupId(self)
def send_request(target, request):
"""Send one request to the target server.
......@@ -593,10 +671,18 @@ def send_request(target, request):
client_id = get_rank()
server_id = target
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id,
data, tensors, group_id=get_group_id())
msg = RPCMessage(
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
group_id=get_group_id(),
)
send_rpc_message(msg, server_id)
def send_request_to_machine(target, request):
"""Send one request to the target machine, which will randomly
select a server node to process this request.
......@@ -620,12 +706,17 @@ def send_request_to_machine(target, request):
service_id = request.service_id
msg_seq = incr_msg_seq()
client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
server_id = random.randint(
target * get_num_server_per_machine(),
(target + 1) * get_num_server_per_machine() - 1,
)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
msg = RPCMessage(
service_id, msg_seq, client_id, server_id, data, tensors, get_group_id()
)
send_rpc_message(msg, server_id)
def send_response(target, response, group_id):
"""Send one response to the target client.
......@@ -655,9 +746,12 @@ def send_response(target, response, group_id):
client_id = target
server_id = get_rank()
data, tensors = serialize_to_payload(response)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, group_id)
msg = RPCMessage(
service_id, msg_seq, client_id, server_id, data, tensors, group_id
)
send_rpc_message(msg, get_client(client_id, group_id))
def recv_request(timeout=0):
"""Receive one request.
......@@ -690,14 +784,19 @@ def recv_request(timeout=0):
set_msg_seq(msg.msg_seq)
req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id]
if req_cls is None:
raise DGLError('Got request message from service ID {}, '
'but no request class is registered.'.format(msg.service_id))
raise DGLError(
"Got request message from service ID {}, "
"but no request class is registered.".format(msg.service_id)
)
req = deserialize_from_payload(req_cls, msg.data, msg.tensors)
if msg.server_id != get_rank():
raise DGLError('Got request sent to server {}, '
'different from my rank {}!'.format(msg.server_id, get_rank()))
raise DGLError(
"Got request sent to server {}, "
"different from my rank {}!".format(msg.server_id, get_rank())
)
return req, msg.client_id, msg.group_id
def recv_response(timeout=0):
"""Receive one response.
......@@ -725,17 +824,24 @@ def recv_response(timeout=0):
return None
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None:
raise DGLError('Got response message from service ID {}, '
'but no response class is registered.'.format(msg.service_id))
raise DGLError(
"Got response message from service ID {}, "
"but no response class is registered.".format(msg.service_id)
)
res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != get_rank() and get_rank() != -1:
raise DGLError('Got response of request sent by client {}, '
'different from my rank {}!'.format(msg.client_id, get_rank()))
raise DGLError(
"Got response of request sent by client {}, "
"different from my rank {}!".format(msg.client_id, get_rank())
)
if msg.group_id != get_group_id():
raise DGLError("Got response of request sent by group {}, "
"different from my group {}!".format(msg.group_id, get_group_id()))
raise DGLError(
"Got response of request sent by group {}, "
"different from my group {}!".format(msg.group_id, get_group_id())
)
return res
def remote_call(target_and_requests, timeout=0):
"""Invoke registered services on remote servers and collect responses.
......@@ -771,10 +877,20 @@ def remote_call(target_and_requests, timeout=0):
service_id = request.service_id
msg_seq = incr_msg_seq()
client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
server_id = random.randint(
target * get_num_server_per_machine(),
(target + 1) * get_num_server_per_machine() - 1,
)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
msg = RPCMessage(
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
get_group_id(),
)
send_rpc_message(msg, server_id)
# check if has response
res_cls = get_service_property(service_id)[1]
......@@ -786,22 +902,28 @@ def remote_call(target_and_requests, timeout=0):
msg = recv_rpc_message(timeout)
if msg is None:
raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds")
f"Timed out for receiving message within {timeout} milliseconds"
)
num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None:
raise DGLError('Got response message from service ID {}, '
'but no response class is registered.'.format(msg.service_id))
raise DGLError(
"Got response message from service ID {}, "
"but no response class is registered.".format(msg.service_id)
)
res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != myrank:
raise DGLError('Got response of request sent by client {}, '
'different from my rank {}!'.format(msg.client_id, myrank))
raise DGLError(
"Got response of request sent by client {}, "
"different from my rank {}!".format(msg.client_id, myrank)
)
# set response
all_res[msgseq2pos[msg.msg_seq]] = res
return all_res
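# Editor's sketch (not part of this commit): remote_call keeps responses
# aligned with their requests. Any Request subclass registered via
# register_service works; the two request arguments below are placeholders
# for such instances.
def _example_remote_call(request_for_machine0, request_for_machine1):
    pairs = [(0, request_for_machine0), (1, request_for_machine1)]
    responses = remote_call(pairs)  # responses[i] answers pairs[i]
    return responses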
def send_requests_to_machine(target_and_requests):
""" Send requests to the remote machines.
"""Send requests to the remote machines.
This operation isn't blocking. It returns immediately once all requests are sent.
......@@ -824,10 +946,20 @@ def send_requests_to_machine(target_and_requests):
msg_seq = incr_msg_seq()
client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
server_id = random.randint(
target * get_num_server_per_machine(),
(target + 1) * get_num_server_per_machine() - 1,
)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id())
msg = RPCMessage(
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
get_group_id(),
)
send_rpc_message(msg, server_id)
# check if has response
res_cls = get_service_property(service_id)[1]
......@@ -835,8 +967,9 @@ def send_requests_to_machine(target_and_requests):
msgseq2pos[msg_seq] = pos
return msgseq2pos
def recv_responses(msgseq2pos, timeout=0):
""" Receive responses
"""Receive responses
It returns the responses in the same order as the requests. The order of requests
is stored in msgseq2pos.
......@@ -866,20 +999,26 @@ def recv_responses(msgseq2pos, timeout=0):
msg = recv_rpc_message(timeout)
if msg is None:
raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds")
f"Timed out for receiving message within {timeout} milliseconds"
)
num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None:
raise DGLError('Got response message from service ID {}, '
'but no response class is registered.'.format(msg.service_id))
raise DGLError(
"Got response message from service ID {}, "
"but no response class is registered.".format(msg.service_id)
)
res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != myrank:
raise DGLError('Got response of request sent by client {}, '
'different from my rank {}!'.format(msg.client_id, myrank))
raise DGLError(
"Got response of request sent by client {}, "
"different from my rank {}!".format(msg.client_id, myrank)
)
# set response
all_res[msgseq2pos[msg.msg_seq]] = res
return all_res
def remote_call_to_machine(target_and_requests, timeout=0):
"""Invoke registered services on remote machine
(which will ramdom select a server to process the request) and collect responses.
......@@ -910,6 +1049,7 @@ def remote_call_to_machine(target_and_requests, timeout=0):
msgseq2pos = send_requests_to_machine(target_and_requests)
return recv_responses(msgseq2pos, timeout)
def send_rpc_message(msg, target):
"""Send one message to the target server.
......@@ -936,6 +1076,7 @@ def send_rpc_message(msg, target):
"""
_CAPI_DGLRPCSendRPCMessage(msg, int(target))
def recv_rpc_message(timeout=0):
"""Receive one message.
......@@ -960,23 +1101,34 @@ def recv_rpc_message(timeout=0):
status = _CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg if status == 0 else None
def client_barrier():
"""Barrier all client processes"""
req = ClientBarrierRequest()
send_request(0, req)
res = recv_response()
assert res.msg == 'barrier'
assert res.msg == "barrier"
def finalize_server():
"""Finalize resources of current server
"""
"""Finalize resources of current server"""
finalize_sender()
finalize_receiver()
print("Server (%d) shutdown." % get_rank())
def fast_pull(name, id_tensor, part_id, service_id,
machine_count, group_count, machine_id,
client_id, local_data, policy):
def fast_pull(
name,
id_tensor,
part_id,
service_id,
machine_count,
group_count,
machine_id,
client_id,
local_data,
policy,
):
"""Fast-pull api used by kvstore.
Parameters
......@@ -1004,39 +1156,45 @@ def fast_pull(name, id_tensor, part_id, service_id,
"""
msg_seq = incr_msg_seq()
pickle_data = bytearray(pickle.dumps(([0], [name])))
global_id = _CAPI_DGLRPCGetGlobalIDFromLocalPartition(F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(part_id),
machine_id)
global_id = _CAPI_DGLRPCGetGlobalIDFromLocalPartition(
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(part_id),
machine_id,
)
global_id = F.zerocopy_from_dgl_ndarray(global_id)
g2l_id = policy.to_local(global_id)
res_tensor = _CAPI_DGLRPCFastPull(name,
int(machine_id),
int(machine_count),
int(group_count),
int(client_id),
int(service_id),
int(msg_seq),
pickle_data,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(part_id),
F.zerocopy_to_dgl_ndarray(g2l_id),
F.zerocopy_to_dgl_ndarray(local_data))
res_tensor = _CAPI_DGLRPCFastPull(
name,
int(machine_id),
int(machine_count),
int(group_count),
int(client_id),
int(service_id),
int(msg_seq),
pickle_data,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(part_id),
F.zerocopy_to_dgl_ndarray(g2l_id),
F.zerocopy_to_dgl_ndarray(local_data),
)
return F.zerocopy_from_dgl_ndarray(res_tensor)
def register_sig_handler():
"""Register for handling signal event.
"""
"""Register for handling signal event."""
_CAPI_DGLRPCHandleSignal()
def copy_data_to_shared_memory(dst, source):
"""Copy tensor data to shared-memory tensor
"""
"""Copy tensor data to shared-memory tensor"""
F.zerocopy_to_dgl_ndarray(dst).copyfrom(F.zerocopy_to_dgl_ndarray(source))
############### Some basic services will be defined here #############
CLIENT_REGISTER = 22451
class ClientRegisterRequest(Request):
"""This request will send client's ip to server.
......@@ -1045,6 +1203,7 @@ class ClientRegisterRequest(Request):
ip_addr : str
client's IP address
"""
def __init__(self, ip_addr):
self.ip_addr = ip_addr
......@@ -1055,7 +1214,8 @@ class ClientRegisterRequest(Request):
self.ip_addr = state
def process_request(self, server_state):
return None # do nothing
return None # do nothing
class ClientRegisterResponse(Response):
"""This response will send assigned ID to client.
......@@ -1065,6 +1225,7 @@ class ClientRegisterResponse(Response):
ID : int
client's ID
"""
def __init__(self, client_id):
self.client_id = client_id
......@@ -1077,6 +1238,7 @@ class ClientRegisterResponse(Response):
SHUT_DOWN_SERVER = 22452
class ShutDownRequest(Request):
"""Client send this request to shut-down a server.
......@@ -1087,6 +1249,7 @@ class ShutDownRequest(Request):
client_id : int
client's ID
"""
def __init__(self, client_id, force_shutdown_server=False):
self.client_id = client_id
self.force_shutdown_server = force_shutdown_server
......@@ -1104,8 +1267,10 @@ class ShutDownRequest(Request):
finalize_server()
return SERVER_EXIT
GET_NUM_CLIENT = 22453
class GetNumberClientsResponse(Response):
"""This reponse will send total number of clients.
......@@ -1114,6 +1279,7 @@ class GetNumberClientsResponse(Response):
num_client : int
total number of clients
"""
def __init__(self, num_client):
self.num_client = num_client
......@@ -1123,6 +1289,7 @@ class GetNumberClientsResponse(Response):
def __setstate__(self, state):
self.num_client = state
class GetNumberClientsRequest(Request):
"""Client send this request to get the total number of client.
......@@ -1131,6 +1298,7 @@ class GetNumberClientsRequest(Request):
client_id : int
client's ID
"""
def __init__(self, client_id):
self.client_id = client_id
......@@ -1144,8 +1312,10 @@ class GetNumberClientsRequest(Request):
res = GetNumberClientsResponse(get_num_client())
return res
CLIENT_BARRIER = 22454
class ClientBarrierResponse(Response):
"""Send the barrier confirmation to client
......@@ -1154,7 +1324,8 @@ class ClientBarrierResponse(Response):
msg : str
string msg
"""
def __init__(self, msg='barrier'):
def __init__(self, msg="barrier"):
self.msg = msg
def __getstate__(self):
......@@ -1163,6 +1334,7 @@ class ClientBarrierResponse(Response):
def __setstate__(self, state):
self.msg = state
class ClientBarrierRequest(Request):
"""Send the barrier information to server
......@@ -1171,7 +1343,8 @@ class ClientBarrierRequest(Request):
msg : str
string msg
"""
def __init__(self, msg='barrier'):
def __init__(self, msg="barrier"):
self.msg = msg
self.group_id = get_group_id()
......@@ -1182,7 +1355,9 @@ class ClientBarrierRequest(Request):
self.msg, self.group_id = state
def process_request(self, server_state):
_CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount(self.group_id)+1, self.group_id)
_CAPI_DGLRPCSetBarrierCount(
_CAPI_DGLRPCGetBarrierCount(self.group_id) + 1, self.group_id
)
if _CAPI_DGLRPCGetBarrierCount(self.group_id) == get_num_client():
_CAPI_DGLRPCSetBarrierCount(0, self.group_id)
res_list = []
......@@ -1191,6 +1366,7 @@ class ClientBarrierRequest(Request):
return res_list
return None
def set_group_id(group_id):
"""Set current group ID
......@@ -1201,6 +1377,7 @@ def set_group_id(group_id):
"""
_CAPI_DGLRPCSetGroupID(int(group_id))
def get_group_id():
"""Get current group ID
......@@ -1211,6 +1388,7 @@ def get_group_id():
"""
return _CAPI_DGLRPCGetGroupID()
def register_client(client_id, group_id):
"""Register client
......@@ -1221,6 +1399,7 @@ def register_client(client_id, group_id):
"""
return _CAPI_DGLRPCRegisterClient(int(client_id), int(group_id))
def get_client(client_id, group_id):
"""Get global client ID
......@@ -1238,6 +1417,7 @@ def get_client(client_id, group_id):
"""
return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))
class DistConnectError(DGLError):
"""Exception raised for errors if fail to connect peer.
......@@ -1247,12 +1427,16 @@ class DistConnectError(DGLError):
reference for KVServer
"""
def __init__(self, max_try_times, ip='', port=''):
peer_str = "peer[{}:{}]".format(ip, port) if ip != '' else "peer"
self.message = "Failed to build conncetion with {} after {} retries. " \
"Please check network availability or increase max try " \
"times via 'DGL_DIST_MAX_TRY_TIMES'.".format(
peer_str, max_try_times)
def __init__(self, max_try_times, ip="", port=""):
peer_str = "peer[{}:{}]".format(ip, port) if ip != "" else "peer"
self.message = (
"Failed to build conncetion with {} after {} retries. "
"Please check network availability or increase max try "
"times via 'DGL_DIST_MAX_TRY_TIMES'.".format(
peer_str, max_try_times
)
)
super().__init__(self.message)
_init_api("dgl.distributed.rpc")
"""Functions used by client."""
import os
import socket
import atexit
import logging
import os
import socket
import time
from . import rpc
from .constants import MAX_QUEUE_SIZE
if os.name != 'nt':
if os.name != "nt":
import fcntl
import struct
def local_ip4_addr_list():
"""Return a set of IPv4 address
......@@ -20,21 +21,25 @@ def local_ip4_addr_list():
`logging.getLogger("dgl-distributed-socket").setLevel(logging.WARNING+1)`
to disable the warning here
"""
assert os.name != 'nt', 'Do not support Windows rpc yet.'
assert os.name != "nt", "Do not support Windows rpc yet."
nic = set()
logger = logging.getLogger("dgl-distributed-socket")
for if_nidx in socket.if_nameindex():
name = if_nidx[1]
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
ip_of_ni = fcntl.ioctl(sock.fileno(),
0x8915, # SIOCGIFADDR
struct.pack('256s', name[:15].encode("UTF-8")))
ip_of_ni = fcntl.ioctl(
sock.fileno(),
0x8915, # SIOCGIFADDR
struct.pack("256s", name[:15].encode("UTF-8")),
)
except OSError as e:
if e.errno == 99: # EADDRNOTAVAIL
if e.errno == 99: # EADDRNOTAVAIL
logger.warning(
"Warning! Interface: %s \n"
"IP address not available for interface.", name)
"IP address not available for interface.",
name,
)
continue
raise e
......@@ -42,6 +47,7 @@ def local_ip4_addr_list():
nic.add(ip_addr)
return nic
def get_local_machine_id(server_namebook):
"""Given server_namebook, find local machine ID
......@@ -76,6 +82,7 @@ def get_local_machine_id(server_namebook):
break
return res
def get_local_usable_addr(probe_addr):
"""Get local usable IP and port
......@@ -90,7 +97,7 @@ def get_local_usable_addr(probe_addr):
sock.connect((probe_addr, 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
ip_addr = "127.0.0.1"
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -99,11 +106,16 @@ def get_local_usable_addr(probe_addr):
port = sock.getsockname()[1]
sock.close()
return ip_addr + ':' + str(port)
return ip_addr + ":" + str(port)
def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
net_type='socket', group_id=0):
def connect_to_server(
ip_config,
num_servers,
max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
group_id=0,
):
"""Connect this client to server.
Parameters
......@@ -127,23 +139,30 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
------
ConnectionError : If anything wrong with the connection.
"""
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket', 'tensorpipe'), \
'net_type (%s) can only be \'socket\' or \'tensorpipe\'.' % net_type
assert num_servers > 0, (
"num_servers (%d) must be a positive number." % num_servers
)
assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size
)
assert net_type in ("socket", "tensorpipe"), (
"net_type (%s) can only be 'socket' or 'tensorpipe'." % net_type
)
# Register some basic service
rpc.register_service(rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
rpc.ClientRegisterResponse)
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.register_service(
rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
rpc.ClientRegisterResponse,
)
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.register_service(
rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse,
)
rpc.register_service(
rpc.CLIENT_BARRIER, rpc.ClientBarrierRequest, rpc.ClientBarrierResponse
)
rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook)
......@@ -157,7 +176,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
if server_info[0] > max_machine_id:
max_machine_id = server_info[0]
rpc.set_num_server_per_machine(group_count[0])
num_machines = max_machine_id+1
num_machines = max_machine_id + 1
rpc.set_num_machines(num_machines)
machine_id = get_local_machine_id(server_namebook)
rpc.set_machine_id(machine_id)
......@@ -165,7 +184,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 1024))
max_try_times = int(os.environ.get("DGL_DIST_MAX_TRY_TIMES", 1024))
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
......@@ -173,39 +192,50 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
while not rpc.connect_receiver(server_ip, server_port, server_id):
try_times += 1
if try_times % 200 == 0:
print("Client is trying to connect server receiver: {}:{}".format(
server_ip, server_port))
print(
"Client is trying to connect server receiver: {}:{}".format(
server_ip, server_port
)
)
if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, server_ip, server_port)
raise rpc.DistConnectError(
max_try_times, server_ip, server_port
)
time.sleep(3)
if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
# Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':')
client_ip, client_port = ip_addr.split(":")
# Register client on server
register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers):
rpc.send_request(server_id, register_req)
# wait server connect back
rpc.wait_for_senders(client_ip, client_port, num_servers,
blocking=net_type == 'socket')
print("Client [{}] waits on {}:{}".format(
os.getpid(), client_ip, client_port))
rpc.wait_for_senders(
client_ip, client_port, num_servers, blocking=net_type == "socket"
)
print(
"Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port)
)
# recv client ID from server
res = rpc.recv_response()
rpc.set_rank(res.client_id)
print("Machine (%d) group (%d) client (%d) connect to server successfuly!" \
% (machine_id, group_id, rpc.get_rank()))
print(
"Machine (%d) group (%d) client (%d) connect to server successfuly!"
% (machine_id, group_id, rpc.get_rank())
)
# get total number of client
get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())
rpc.send_request(0, get_client_num_req)
res = rpc.recv_response()
rpc.set_num_client(res.num_client)
from .dist_context import exit_client, set_initialized
atexit.register(exit_client)
set_initialized(True)
def shutdown_servers(ip_config, num_servers):
"""Issue commands to remote servers to shut them down.
......@@ -229,13 +259,11 @@ def shutdown_servers(ip_config, num_servers):
------
ConnectionError : If anything wrong with the connection.
"""
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook)
rpc.create_sender(MAX_QUEUE_SIZE, 'tensorpipe')
rpc.create_sender(MAX_QUEUE_SIZE, "tensorpipe")
# Get connected with all server nodes
for server_id, addr in server_namebook.items():
server_ip = addr[1]
......
"""Functions used by server."""
import time
import os
import time
from ..base import DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
def start_server(
server_id,
ip_config,
num_servers,
num_clients,
server_state,
max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
):
"""Start DGL server, which will be shared with all the rpc services.
This is a blocking function -- it returns only when the server shuts down.
......@@ -34,33 +43,47 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
net_type : str
Networking type. Current options are: ``'socket'`` or ``'tensorpipe'``.
"""
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket', 'tensorpipe'), \
'net_type (%s) can only be \'socket\' or \'tensorpipe\'' % net_type
assert server_id >= 0, (
"server_id (%d) cannot be a negative number." % server_id
)
assert num_servers > 0, (
"num_servers (%d) must be a positive number." % num_servers
)
assert num_clients >= 0, (
"num_client (%d) cannot be a negative number." % num_clients
)
assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size
)
assert net_type in ("socket", "tensorpipe"), (
"net_type (%s) can only be 'socket' or 'tensorpipe'" % net_type
)
if server_state.keep_alive:
assert net_type == 'tensorpipe', \
"net_type can only be 'tensorpipe' if 'keep_alive' is enabled."
print("As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received."
" [WARNING] This feature is experimental and not fully tested.")
assert (
net_type == "tensorpipe"
), "net_type can only be 'tensorpipe' if 'keep_alive' is enabled."
print(
"As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received."
" [WARNING] This feature is experimental and not fully tested."
)
# Register signal handler.
rpc.register_sig_handler()
# Register some basic services
rpc.register_service(rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
rpc.ClientRegisterResponse)
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.register_service(
rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
rpc.ClientRegisterResponse,
)
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.register_service(
rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse,
)
rpc.register_service(
rpc.CLIENT_BARRIER, rpc.ClientBarrierRequest, rpc.ClientBarrierResponse
)
rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config, num_servers)
machine_id = server_namebook[server_id][0]
......@@ -73,9 +96,11 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# Once all the senders connect to server, server will not
# accept new sender's connection
print(
"Server is waiting for connections on [{}:{}]...".format(ip_addr, port))
rpc.wait_for_senders(ip_addr, port, num_clients,
blocking=net_type == 'socket')
"Server is waiting for connections on [{}:{}]...".format(ip_addr, port)
)
rpc.wait_for_senders(
ip_addr, port, num_clients, blocking=net_type == "socket"
)
rpc.set_num_client(num_clients)
recv_clients = {}
while True:
......@@ -89,18 +114,25 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# a new client group is ready
ips.sort()
client_namebook = dict(enumerate(ips))
time.sleep(3) # wait for clients' receivers ready
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 120))
time.sleep(3)  # wait for clients' receivers to be ready
max_try_times = int(os.environ.get("DGL_DIST_MAX_TRY_TIMES", 120))
for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':')
client_ip, client_port = addr.split(":")
try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
while not rpc.connect_receiver(
client_ip, client_port, client_id, group_id
):
try_times += 1
if try_times % 200 == 0:
print("Server~{} is trying to connect client receiver: {}:{}".format(
server_id, client_ip, client_port))
print(
"Server~{} is trying to connect client receiver: {}:{}".format(
server_id, client_ip, client_port
)
)
if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, client_ip, client_port)
raise rpc.DistConnectError(
max_try_times, client_ip, client_port
)
time.sleep(1)
if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
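The loop above is a bounded-retry connect: poll, count attempts, log every 200 tries, give up after DGL_DIST_MAX_TRY_TIMES. A minimal, self-contained sketch of the same control flow, with a stand-in connect function (illustration only):

import time

def connect_with_retry(try_connect, max_try_times=120, log_every=200):
    # try_connect stands in for rpc.connect_receiver: it returns True on
    # success and False while the peer's receiver is not yet reachable.
    try_times = 0
    while not try_connect():
        try_times += 1
        if try_times % log_every == 0:
            print("still trying to connect (%d attempts)" % try_times)
        if try_times >= max_try_times:
            raise ConnectionError("gave up after %d attempts" % max_try_times)
        time.sleep(1)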
......@@ -130,7 +162,11 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
print("Server is exiting...")
return
elif res == SERVER_KEEP_ALIVE:
print("Server keeps alive while client group~{} is exiting...".format(group_id))
print(
"Server keeps alive while client group~{} is exiting...".format(
group_id
)
)
else:
raise DGLError("Unexpected response: {}".format(res))
else:
......
......@@ -77,4 +77,5 @@ class ServerState:
"""Flag of whether keep alive"""
return self._keep_alive
_init_api("dgl.distributed.server_state")
"""Define utility functions for shared memory."""
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
from .. import ndarray as nd
from .._ffi.ndarray import empty_shared_mem
DTYPE_DICT = F.data_type_dict
DTYPE_DICT = {DTYPE_DICT[key]:key for key in DTYPE_DICT}
DTYPE_DICT = {DTYPE_DICT[key]: key for key in DTYPE_DICT}
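The second assignment inverts the backend's name-to-dtype table so a dtype object can be looked up by value and mapped back to its string name, which is what empty_shared_mem expects. The idiom, shown with a placeholder table (DGL's real one comes from F.data_type_dict):

# Placeholder mapping only; the real table maps names like "float32"
# to backend dtype objects.
name_to_dtype = {"float32": "f32", "int64": "i64"}
dtype_to_name = {dtype: name for name, dtype in name_to_dtype.items()}
assert dtype_to_name["f32"] == "float32"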
def _get_ndata_path(graph_name, ndata_name):
return "/" + graph_name + "_node_" + ndata_name
def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name
def _to_shared_mem(arr, name):
dlpack = F.zerocopy_to_dlpack(arr)
dgl_tensor = nd.from_dlpack(dlpack)
new_arr = empty_shared_mem(name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)])
new_arr = empty_shared_mem(
name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)]
)
dgl_tensor.copyto(new_arr)
dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)
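_to_shared_mem round-trips a tensor through DLPack into a named shared-memory ndarray and back, so other processes on the same machine can map the same data by name. A hedged usage sketch, assuming the PyTorch backend and that these private helpers live in dgl.distributed.shared_mem_utils; the graph and feature names are made up:

import torch
from dgl.distributed.shared_mem_utils import _get_ndata_path, _to_shared_mem

feat = torch.randn(100, 16)
# The returned tensor is backed by shared memory under the given name,
# so another process can attach to it via the same path.
shared_feat = _to_shared_mem(feat, _get_ndata_path("demo_graph", "feat"))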
......@@ -10,6 +10,7 @@ from .init import zero_initializer
from .storages import TensorStorage
from .utils import gather_pinned_tensor_rows, pin_memory_inplace
class _LazyIndex(object):
def __init__(self, index):
if isinstance(index, list):
......@@ -21,17 +22,17 @@ class _LazyIndex(object):
return len(self._indices[-1])
def slice(self, index):
""" Create a new _LazyIndex object sliced by the given index tensor.
"""
"""Create a new _LazyIndex object sliced by the given index tensor."""
# If our indices are in the same context, let's just slice now and free
# memory; otherwise do nothing until we have to.
if F.context(self._indices[-1]) == F.context(index):
return _LazyIndex(self._indices[:-1] + [F.gather_row(self._indices[-1], index)])
return _LazyIndex(
self._indices[:-1] + [F.gather_row(self._indices[-1], index)]
)
return _LazyIndex(self._indices + [index])
def flatten(self):
""" Evaluate the chain of indices, and return a single index tensor.
"""
"""Evaluate the chain of indices, and return a single index tensor."""
flat_index = self._indices[0]
# here we actually need to resolve it
for index in self._indices[1:]:
......@@ -40,6 +41,7 @@ class _LazyIndex(object):
flat_index = F.gather_row(flat_index, index)
return flat_index
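_LazyIndex is sound because row gathers compose: slicing by i and then by j equals slicing once by i[j], so a chain of deferred slices can be collapsed into a single gather in flatten(). A quick PyTorch check of that identity:

import torch

x = torch.arange(10) * 10
i = torch.tensor([2, 4, 6, 8])
j = torch.tensor([0, 3])
# x[i][j] picks rows 2 and 8 of x, exactly like x[i[j]].
assert torch.equal(x[i][j], x[i[j]])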
class LazyFeature(object):
"""Placeholder for feature prefetching.
......@@ -81,12 +83,16 @@ class LazyFeature(object):
id_ : Tensor, optional
The ID tensor.
"""
__slots__ = ['name', 'id_']
__slots__ = ["name", "id_"]
def __init__(self, name=None, id_=None):
self.name = name
self.id_ = id_
def to(self, *args, **kwargs): # pylint: disable=invalid-name, unused-argument
def to(
self, *args, **kwargs
): # pylint: disable=invalid-name, unused-argument
"""No-op. For compatibility of :meth:`Frame.to` method."""
return self
......@@ -104,7 +110,8 @@ class LazyFeature(object):
def record_stream(self, stream):
"""No-op. For compatibility of :meth:`Frame.record_stream` method."""
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
class Scheme(namedtuple("Scheme", ["shape", "dtype"])):
"""The column scheme.
Parameters
......@@ -114,6 +121,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype : backend-specific type object
The feature data type.
"""
# Pickling torch dtypes could be problematic; this is a workaround.
# I also have to create data_type_dict and reverse_data_type_dict
# attributes just for this bug.
......@@ -128,6 +136,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype = F.data_type_dict[dtype_str]
return cls(shape, dtype)
def infer_scheme(tensor):
"""Infer column scheme from the given tensor data.
......@@ -143,6 +152,7 @@ def infer_scheme(tensor):
"""
return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
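A scheme is simply everything after the leading row dimension, plus the dtype. For example, under the PyTorch backend (assuming this module is importable as dgl.frame):

import torch
from dgl.frame import infer_scheme

feat = torch.zeros(100, 16, 3)     # 100 rows of 16x3 features
scheme = infer_scheme(feat)
print(scheme.shape, scheme.dtype)  # (16, 3) torch.float32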
class Column(TensorStorage):
"""A column is a compact store of features of multiple nodes/edges.
......@@ -185,6 +195,7 @@ class Column(TensorStorage):
index : Tensor
Index tensor
"""
def __init__(self, storage, *args, **kwargs):
super().__init__(storage)
self._init(*args, **kwargs)
......@@ -212,8 +223,14 @@ class Column(TensorStorage):
index_ctx = F.context(self.index)
# Under the special case where the storage is pinned and the index is on
# CUDA, directly call UVA slicing (even if they are not in the same context).
if storage_ctx != index_ctx and storage_ctx == F.cpu() and F.is_pinned(self.storage):
self.storage = gather_pinned_tensor_rows(self.storage, self.index)
if (
storage_ctx != index_ctx
and storage_ctx == F.cpu()
and F.is_pinned(self.storage)
):
self.storage = gather_pinned_tensor_rows(
self.storage, self.index
)
else:
# If index and storage is not in the same context,
# copy index to the same context of storage.
......@@ -228,7 +245,9 @@ class Column(TensorStorage):
# move data to the right device
if self.device is not None:
self.storage = F.copy_to(self.storage, self.device[0], **self.device[1])
self.storage = F.copy_to(
self.storage, self.device[0], **self.device[1]
)
self.device = None
# convert data to the right type
......@@ -247,8 +266,8 @@ class Column(TensorStorage):
self._data_nd = None # should unpin data if it was pinned.
self.pinned_by_dgl = False
def to(self, device, **kwargs): # pylint: disable=invalid-name
""" Return a new column with columns copy to the targeted device (cpu/gpu).
def to(self, device, **kwargs): # pylint: disable=invalid-name
"""Return a new column with columns copy to the targeted device (cpu/gpu).
Parameters
----------
......@@ -268,13 +287,13 @@ class Column(TensorStorage):
@property
def dtype(self):
""" Return the effective data type of this Column """
"""Return the effective data type of this Column"""
if self.deferred_dtype is not None:
return self.deferred_dtype
return self.storage.dtype
def astype(self, new_dtype):
""" Return a new column such that when its data is requested,
"""Return a new column such that when its data is requested,
it will be converted to new_dtype.
Parameters
......@@ -353,8 +372,10 @@ class Column(TensorStorage):
"""
feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme))
raise DGLError(
"Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)
)
self.data = F.scatter_row(self.data, rowids, feats)
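F.scatter_row replaces exactly the rows selected by rowids and leaves the rest untouched (returning a new tensor). Plain PyTorch index assignment conveys the same row-replacement semantics (sketch only):

import torch

data = torch.zeros(5, 2)
rowids = torch.tensor([1, 3])
feats = torch.ones(2, 2)
# In-place equivalent of the out-of-place F.scatter_row(data, rowids, feats):
data[rowids] = feats
assert data.sum() == 4  # only rows 1 and 3 were overwritten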
def extend(self, feats, feat_scheme=None):
......@@ -373,14 +394,22 @@ class Column(TensorStorage):
feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme))
raise DGLError(
"Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)
)
self.data = F.cat([self.data, feats], dim=0)
def clone(self):
"""Return a shallow copy of this column."""
return Column(self.storage, self.scheme, self.index, self.device, self.deferred_dtype)
return Column(
self.storage,
self.scheme,
self.index,
self.device,
self.deferred_dtype,
)
def deepclone(self):
"""Return a deepcopy of this column.
......@@ -409,13 +438,25 @@ class Column(TensorStorage):
Sub-column
"""
if self.index is None:
return Column(self.storage, self.scheme, rowids, self.device, self.deferred_dtype)
return Column(
self.storage,
self.scheme,
rowids,
self.device,
self.deferred_dtype,
)
else:
index = self.index
if not isinstance(index, _LazyIndex):
index = _LazyIndex(self.index)
index = index.slice(rowids)
return Column(self.storage, self.scheme, index, self.device, self.deferred_dtype)
return Column(
self.storage,
self.scheme,
index,
self.device,
self.deferred_dtype,
)
@staticmethod
def create(data):
......@@ -435,30 +476,32 @@ class Column(TensorStorage):
state = self.__dict__.copy()
# data pinning does not get serialized, so we need to remove that from
# the state
state['_data_nd'] = None
state['pinned_by_dgl'] = False
state["_data_nd"] = None
state["pinned_by_dgl"] = False
return state
def __setstate__(self, state):
index = None
device = None
if 'storage' in state and state['storage'] is not None:
assert 'index' not in state or state['index'] is None
assert 'device' not in state or state['device'] is None
if "storage" in state and state["storage"] is not None:
assert "index" not in state or state["index"] is None
assert "device" not in state or state["device"] is None
else:
# we may have a column with only index information, and that is
# valid
index = None if 'index' not in state else state['index']
device = None if 'device' not in state else state['device']
assert 'deferred_dtype' not in state or state['deferred_dtype'] is None
assert 'pinned_by_dgl' not in state or state['pinned_by_dgl'] is False
assert '_data_nd' not in state or state['_data_nd'] is None
index = None if "index" not in state else state["index"]
device = None if "device" not in state else state["device"]
assert "deferred_dtype" not in state or state["deferred_dtype"] is None
assert "pinned_by_dgl" not in state or state["pinned_by_dgl"] is False
assert "_data_nd" not in state or state["_data_nd"] is None
self.__dict__ = state
# properly initialize this object
self._init(self.scheme if hasattr(self, 'scheme') else None,
index=index,
device=device)
self._init(
self.scheme if hasattr(self, "scheme") else None,
index=index,
device=device,
)
def _init(self, scheme=None, index=None, device=None, deferred_dtype=None):
self.scheme = scheme if scheme else infer_scheme(self.storage)
......@@ -472,7 +515,7 @@ class Column(TensorStorage):
return self.clone()
def fetch(self, indices, device, pin_memory=False, **kwargs):
_ = self.data # materialize in case of lazy slicing & data transfer
_ = self.data # materialize in case of lazy slicing & data transfer
return super().fetch(indices, device, pin_memory=pin_memory, **kwargs)
def pin_memory_(self):
......@@ -503,10 +546,11 @@ class Column(TensorStorage):
----------
stream : torch.cuda.Stream.
"""
if F.get_preferred_backend() != 'pytorch':
if F.get_preferred_backend() != "pytorch":
raise DGLError("record_stream only supports the PyTorch backend.")
self.data.record_stream(stream)
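The call above is the standard CUDA-stream safety idiom: recording a side stream on a tensor keeps the caching allocator from reusing its memory while that stream may still be reading it. Sketched on a bare tensor (requires a CUDA-enabled PyTorch build; illustration only):

import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    x = torch.randn(1024, device="cuda")
    with torch.cuda.stream(side):
        y = x * 2  # x is consumed on the side stream
    # Tell the allocator x is in use on `side` until that work completes.
    x.record_stream(side)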
class Frame(MutableMapping):
"""The columnar storage for node/edge features.
......@@ -523,6 +567,7 @@ class Frame(MutableMapping):
The number of rows in this frame. If ``data`` is provided and is not empty,
``num_rows`` will be ignored and inferred from the given data.
"""
def __init__(self, data=None, num_rows=None):
if data is None:
self._columns = dict()
......@@ -531,8 +576,10 @@ class Frame(MutableMapping):
assert not isinstance(data, Frame) # sanity check for code refactor
# Note that we always create a new column for the given data.
# This avoids two frames accidentally sharing the same column.
self._columns = {k : v if isinstance(v, LazyFeature) else Column.create(v)
for k, v in data.items()}
self._columns = {
k: v if isinstance(v, LazyFeature) else Column.create(v)
for k, v in data.items()
}
self._num_rows = num_rows
# infer num_rows & sanity check
for name, col in self._columns.items():
......@@ -541,8 +588,10 @@ class Frame(MutableMapping):
if self._num_rows is None:
self._num_rows = len(col)
elif len(col) != self._num_rows:
raise DGLError('Expected all columns to have same # rows (%d), '
'got %d on %r.' % (self._num_rows, len(col), name))
raise DGLError(
"Expected all columns to have same # rows (%d), "
"got %d on %r." % (self._num_rows, len(col), name)
)
# Initializer for empty values. Initializer is a callable.
# If it is None, then a warning will be raised
......@@ -590,7 +639,7 @@ class Frame(MutableMapping):
@property
def schemes(self):
"""Return a dictionary of column name to column schemes."""
return {k : col.scheme for k, col in self._columns.items()}
return {k: col.scheme for k, col in self._columns.items()}
@property
def num_columns(self):
......@@ -658,14 +707,21 @@ class Frame(MutableMapping):
The column context.
"""
if name in self:
dgl_warning('Column "%s" already exists. Ignoring this addition.' % name)
dgl_warning(
'Column "%s" already exists. Ignore adding this column again.'
% name
)
return
if self.get_initializer(name) is None:
self._set_zero_default_initializer()
initializer = self.get_initializer(name)
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows))
init_data = initializer(
(self.num_rows,) + scheme.shape,
scheme.dtype,
ctx,
slice(0, self.num_rows),
)
self._columns[name] = Column(init_data, scheme)
def add_rows(self, num_rows):
......@@ -686,8 +742,12 @@ class Frame(MutableMapping):
if self.get_initializer(key) is None:
self._set_zero_default_initializer()
initializer = self.get_initializer(key)
new_data = initializer((num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows))
new_data = initializer(
(num_rows,) + scheme.shape,
scheme.dtype,
ctx,
slice(self._num_rows, self._num_rows + num_rows),
)
feat_placeholders[key] = new_data
self._append(Frame(feat_placeholders))
self._num_rows += num_rows
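With the default zero initializer, add_rows amounts to appending a zero block of the right shape and dtype to every column. The same effect on a plain dict of tensors (sketch):

import torch

columns = {"x": torch.randn(4, 3), "y": torch.randn(4, 2)}
num_rows = 2
columns = {
    name: torch.cat(
        [col, torch.zeros((num_rows,) + col.shape[1:], dtype=col.dtype)]
    )
    for name, col in columns.items()
}
assert all(col.shape[0] == 6 for col in columns.values())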
......@@ -708,8 +768,10 @@ class Frame(MutableMapping):
col = Column.create(data)
if len(col) != self.num_rows:
raise DGLError('Expected data to have %d rows, got %d.' %
(self.num_rows, len(col)))
raise DGLError(
"Expected data to have %d rows, got %d."
% (self.num_rows, len(col))
)
self._columns[name] = col
def update_row(self, rowids, data):
......@@ -747,9 +809,12 @@ class Frame(MutableMapping):
if self.get_initializer(key) is None:
self._set_zero_default_initializer()
initializer = self.get_initializer(key)
new_data = initializer((other.num_rows,) + scheme.shape,
scheme.dtype, ctx,
slice(self._num_rows, self._num_rows + other.num_rows))
new_data = initializer(
(other.num_rows,) + scheme.shape,
scheme.dtype,
ctx,
slice(self._num_rows, self._num_rows + other.num_rows),
)
other[key] = new_data
# append other to self
for key, col in other._columns.items():
......@@ -829,8 +894,10 @@ class Frame(MutableMapping):
Frame
A deep-cloned frame.
"""
newframe = Frame({k : col.deepclone() for k, col in self._columns.items()},
self._num_rows)
newframe = Frame(
{k: col.deepclone() for k, col in self._columns.items()},
self._num_rows,
)
newframe._initializers = self._initializers
newframe._default_initializer = self._default_initializer
return newframe
......@@ -851,14 +918,14 @@ class Frame(MutableMapping):
Frame
A new subframe.
"""
subcols = {k : col.subcolumn(rowids) for k, col in self._columns.items()}
subcols = {k: col.subcolumn(rowids) for k, col in self._columns.items()}
subf = Frame(subcols, len(rowids))
subf._initializers = self._initializers
subf._default_initializer = self._default_initializer
return subf
def to(self, device, **kwargs): # pylint: disable=invalid-name
""" Return a new frame with columns copy to the targeted device (cpu/gpu).
def to(self, device, **kwargs): # pylint: disable=invalid-name
"""Return a new frame with columns copy to the targeted device (cpu/gpu).
Parameters
----------
......@@ -873,7 +940,10 @@ class Frame(MutableMapping):
A new frame
"""
newframe = self.clone()
new_columns = {key : col.to(device, **kwargs) for key, col in newframe._columns.items()}
new_columns = {
key: col.to(device, **kwargs)
for key, col in newframe._columns.items()
}
newframe._columns = new_columns
return newframe
......@@ -899,8 +969,11 @@ class Frame(MutableMapping):
column.record_stream(stream)
def _astype_float(self, new_type):
assert new_type in [F.float64, F.float32, F.float16], \
"'new_type' must be floating-point type: %s" % str(new_type)
assert new_type in [
F.float64,
F.float32,
F.float16,
], "'new_type' must be floating-point type: %s" % str(new_type)
newframe = self.clone()
new_columns = {}
for name, column in self._columns.items():
......@@ -913,16 +986,16 @@ class Frame(MutableMapping):
return newframe
def half(self):
""" Return a new frame with all floating-point columns converted
to half-precision (float16) """
"""Return a new frame with all floating-point columns converted
to half-precision (float16)"""
return self._astype_float(F.float16)
def float(self):
""" Return a new frame with all floating-point columns converted
to single-precision (float32) """
"""Return a new frame with all floating-point columns converted
to single-precision (float32)"""
return self._astype_float(F.float32)
def double(self):
""" Return a new frame with all floating-point columns converted
to double-precision (float64) """
"""Return a new frame with all floating-point columns converted
to double-precision (float64)"""
return self._astype_float(F.float64)
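Note that half/float/double convert only columns whose dtype is already floating-point; integer columns such as node IDs keep their type. The filtering rule, sketched on a plain dict of tensors:

import torch

def astype_float(columns, new_dtype):
    # Mirrors the rule above: only floating-point columns are converted.
    assert new_dtype in (torch.float16, torch.float32, torch.float64)
    return {
        name: col.to(new_dtype) if col.is_floating_point() else col
        for name, col in columns.items()
    }

cols = {"feat": torch.randn(4, 8), "ids": torch.arange(4)}
half = astype_float(cols, torch.float16)
assert half["feat"].dtype == torch.float16 and half["ids"].dtype == torch.int64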
......@@ -2,6 +2,6 @@
# pylint: disable=redefined-builtin
from __future__ import absolute_import
from .base import *
from .message import *
from .reducer import *
from .base import *