Unverified Commit a208e886 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
"""DGL distributed module""" """DGL distributed module"""
from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split from . import optim
from .dist_context import exit_client, initialize
from .dist_dataloader import DistDataLoader
from .dist_graph import DistGraph, DistGraphServer, edge_split, node_split
from .dist_tensor import DistTensor from .dist_tensor import DistTensor
from .partition import partition_graph, load_partition, load_partition_feats, load_partition_book
from .graph_partition_book import GraphPartitionBook, PartitionPolicy from .graph_partition_book import GraphPartitionBook, PartitionPolicy
from .graph_services import *
from .kvstore import KVClient, KVServer
from .nn import * from .nn import *
from . import optim from .partition import (
load_partition,
load_partition_book,
load_partition_feats,
partition_graph,
)
from .rpc import * from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, shutdown_servers from .rpc_client import connect_to_server, shutdown_servers
from .dist_context import initialize, exit_client from .rpc_server import start_server
from .kvstore import KVServer, KVClient
from .server_state import ServerState from .server_state import ServerState
from .dist_dataloader import DistDataLoader
from .graph_services import *
"""Define all the constants used by DGL rpc""" """Define all the constants used by DGL rpc"""
# Maximum size of message queue in bytes # Maximum size of message queue in bytes
MAX_QUEUE_SIZE = 20*1024*1024*1024 MAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024
SERVER_EXIT = "server_exit" SERVER_EXIT = "server_exit"
SERVER_KEEP_ALIVE = "server_keep_alive" SERVER_KEEP_ALIVE = "server_keep_alive"
DEFAULT_NTYPE = '_N' DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, '_E', DEFAULT_NTYPE) DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
"""Initialize the distributed services""" """Initialize the distributed services"""
# pylint: disable=line-too-long # pylint: disable=line-too-long
import multiprocessing as mp
import traceback
import atexit import atexit
import time import gc
import multiprocessing as mp
import os import os
import sys
import queue import queue
import gc import sys
import time
import traceback
from enum import Enum from enum import Enum
from .. import utils
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore from .kvstore import close_kvstore, init_kvstore
from .rpc_client import connect_to_server
from .role import init_role from .role import init_role
from .. import utils from .rpc_client import connect_to_server
SAMPLER_POOL = None SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0 NUM_SAMPLER_WORKERS = 0
...@@ -34,13 +34,22 @@ def get_sampler_pool(): ...@@ -34,13 +34,22 @@ def get_sampler_pool():
return SAMPLER_POOL, NUM_SAMPLER_WORKERS return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads, group_id): def _init_rpc(
''' This init function is called in the worker processes. ip_config,
''' num_servers,
max_queue_size,
net_type,
role,
num_threads,
group_id,
):
"""This init function is called in the worker processes."""
try: try:
utils.set_num_threads(num_threads) utils.set_num_threads(num_threads)
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone': if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id) connect_to_server(
ip_config, num_servers, max_queue_size, net_type, group_id
)
init_role(role) init_role(role)
init_kvstore(ip_config, num_servers, role) init_kvstore(ip_config, num_servers, role)
except Exception as e: except Exception as e:
...@@ -51,6 +60,7 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread ...@@ -51,6 +60,7 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread
class MpCommand(Enum): class MpCommand(Enum):
"""Enum class for multiprocessing command""" """Enum class for multiprocessing command"""
INIT_RPC = 0 # Not used in the task queue INIT_RPC = 0 # Not used in the task queue
SET_COLLATE_FN = 1 SET_COLLATE_FN = 1
CALL_BARRIER = 2 CALL_BARRIER = 2
...@@ -80,12 +90,16 @@ def init_process(rpc_config, mp_contexts): ...@@ -80,12 +90,16 @@ def init_process(rpc_config, mp_contexts):
elif command == MpCommand.CALL_BARRIER: elif command == MpCommand.CALL_BARRIER:
barrier.wait() barrier.wait()
elif command == MpCommand.DELETE_COLLATE_FN: elif command == MpCommand.DELETE_COLLATE_FN:
dataloader_name, = args (dataloader_name,) = args
del collate_fn_dict[dataloader_name] del collate_fn_dict[dataloader_name]
elif command == MpCommand.CALL_COLLATE_FN: elif command == MpCommand.CALL_COLLATE_FN:
dataloader_name, collate_args = args dataloader_name, collate_args = args
data_queue.put( data_queue.put(
(dataloader_name, collate_fn_dict[dataloader_name](collate_args))) (
dataloader_name,
collate_fn_dict[dataloader_name](collate_args),
)
)
elif command == MpCommand.CALL_FN_ALL_WORKERS: elif command == MpCommand.CALL_FN_ALL_WORKERS:
func, func_args = args func, func_args = args
func(func_args) func(func_args)
...@@ -101,6 +115,7 @@ def init_process(rpc_config, mp_contexts): ...@@ -101,6 +115,7 @@ def init_process(rpc_config, mp_contexts):
class CustomPool: class CustomPool:
"""Customized worker pool""" """Customized worker pool"""
def __init__(self, num_workers, rpc_config): def __init__(self, num_workers, rpc_config):
""" """
Customized worker pool init function Customized worker pool init function
...@@ -117,8 +132,13 @@ class CustomPool: ...@@ -117,8 +132,13 @@ class CustomPool:
for _ in range(num_workers): for _ in range(num_workers):
task_queue = ctx.Queue(self.queue_size) task_queue = ctx.Queue(self.queue_size)
self.task_queues.append(task_queue) self.task_queues.append(task_queue)
proc = ctx.Process(target=init_process, args=( proc = ctx.Process(
rpc_config, (self.result_queue, task_queue, self.barrier))) target=init_process,
args=(
rpc_config,
(self.result_queue, task_queue, self.barrier),
),
)
proc.daemon = True proc.daemon = True
proc.start() proc.start()
self.process_list.append(proc) self.process_list.append(proc)
...@@ -127,20 +147,23 @@ class CustomPool: ...@@ -127,20 +147,23 @@ class CustomPool:
"""Set collate function in subprocess""" """Set collate function in subprocess"""
for i in range(self.num_workers): for i in range(self.num_workers):
self.task_queues[i].put( self.task_queues[i].put(
(MpCommand.SET_COLLATE_FN, (dataloader_name, func))) (MpCommand.SET_COLLATE_FN, (dataloader_name, func))
)
def submit_task(self, dataloader_name, args): def submit_task(self, dataloader_name, args):
"""Submit task to workers""" """Submit task to workers"""
# Round robin # Round robin
self.task_queues[self.current_proc_id].put( self.task_queues[self.current_proc_id].put(
(MpCommand.CALL_COLLATE_FN, (dataloader_name, args))) (MpCommand.CALL_COLLATE_FN, (dataloader_name, args))
)
self.current_proc_id = (self.current_proc_id + 1) % self.num_workers self.current_proc_id = (self.current_proc_id + 1) % self.num_workers
def submit_task_to_all_workers(self, func, args): def submit_task_to_all_workers(self, func, args):
"""Submit task to all workers""" """Submit task to all workers"""
for i in range(self.num_workers): for i in range(self.num_workers):
self.task_queues[i].put( self.task_queues[i].put(
(MpCommand.CALL_FN_ALL_WORKERS, (func, args))) (MpCommand.CALL_FN_ALL_WORKERS, (func, args))
)
def get_result(self, dataloader_name, timeout=1800): def get_result(self, dataloader_name, timeout=1800):
"""Get result from result queue""" """Get result from result queue"""
...@@ -152,20 +175,21 @@ class CustomPool: ...@@ -152,20 +175,21 @@ class CustomPool:
"""Delete collate function""" """Delete collate function"""
for i in range(self.num_workers): for i in range(self.num_workers):
self.task_queues[i].put( self.task_queues[i].put(
(MpCommand.DELETE_COLLATE_FN, (dataloader_name, ))) (MpCommand.DELETE_COLLATE_FN, (dataloader_name,))
)
def call_barrier(self): def call_barrier(self):
"""Call barrier at all workers""" """Call barrier at all workers"""
for i in range(self.num_workers): for i in range(self.num_workers):
self.task_queues[i].put( self.task_queues[i].put((MpCommand.CALL_BARRIER, tuple()))
(MpCommand.CALL_BARRIER, tuple()))
def close(self): def close(self):
"""Close worker pool""" """Close worker pool"""
for i in range(self.num_workers): for i in range(self.num_workers):
self.task_queues[i].put((MpCommand.FINALIZE_POOL, tuple()), block=False) self.task_queues[i].put(
time.sleep(0.5) # Fix for early python version (MpCommand.FINALIZE_POOL, tuple()), block=False
)
time.sleep(0.5) # Fix for early python version
def join(self): def join(self):
"""Join the close process of worker pool""" """Join the close process of worker pool"""
...@@ -173,8 +197,12 @@ class CustomPool: ...@@ -173,8 +197,12 @@ class CustomPool:
self.process_list[i].join() self.process_list[i].join()
def initialize(ip_config, max_queue_size=MAX_QUEUE_SIZE, def initialize(
net_type='socket', num_worker_threads=1): ip_config,
max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
num_worker_threads=1,
):
"""Initialize DGL's distributed module """Initialize DGL's distributed module
This function initializes DGL's distributed module. It acts differently in server This function initializes DGL's distributed module. It acts differently in server
...@@ -204,59 +232,84 @@ def initialize(ip_config, max_queue_size=MAX_QUEUE_SIZE, ...@@ -204,59 +232,84 @@ def initialize(ip_config, max_queue_size=MAX_QUEUE_SIZE,
distributed API. For example, when used with Pytorch, users have to invoke this function distributed API. For example, when used with Pytorch, users have to invoke this function
before Pytorch's `pytorch.distributed.init_process_group`. before Pytorch's `pytorch.distributed.init_process_group`.
""" """
if os.environ.get('DGL_ROLE', 'client') == 'server': if os.environ.get("DGL_ROLE", "client") == "server":
from .dist_graph import DistGraphServer from .dist_graph import DistGraphServer
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server' assert (
assert os.environ.get('DGL_IP_CONFIG') is not None, \ os.environ.get("DGL_SERVER_ID") is not None
'Please define DGL_IP_CONFIG to run DistGraph server' ), "Please define DGL_SERVER_ID to run DistGraph server"
assert os.environ.get('DGL_NUM_SERVER') is not None, \ assert (
'Please define DGL_NUM_SERVER to run DistGraph server' os.environ.get("DGL_IP_CONFIG") is not None
assert os.environ.get('DGL_NUM_CLIENT') is not None, \ ), "Please define DGL_IP_CONFIG to run DistGraph server"
'Please define DGL_NUM_CLIENT to run DistGraph server' assert (
assert os.environ.get('DGL_CONF_PATH') is not None, \ os.environ.get("DGL_NUM_SERVER") is not None
'Please define DGL_CONF_PATH to run DistGraph server' ), "Please define DGL_NUM_SERVER to run DistGraph server"
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',') assert (
os.environ.get("DGL_NUM_CLIENT") is not None
), "Please define DGL_NUM_CLIENT to run DistGraph server"
assert (
os.environ.get("DGL_CONF_PATH") is not None
), "Please define DGL_CONF_PATH to run DistGraph server"
formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
formats = [f.strip() for f in formats] formats = [f.strip() for f in formats]
rpc.reset() rpc.reset()
keep_alive = bool(int(os.environ.get('DGL_KEEP_ALIVE', 0))) keep_alive = bool(int(os.environ.get("DGL_KEEP_ALIVE", 0)))
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')), serv = DistGraphServer(
os.environ.get('DGL_IP_CONFIG'), int(os.environ.get("DGL_SERVER_ID")),
int(os.environ.get('DGL_NUM_SERVER')), os.environ.get("DGL_IP_CONFIG"),
int(os.environ.get('DGL_NUM_CLIENT')), int(os.environ.get("DGL_NUM_SERVER")),
os.environ.get('DGL_CONF_PATH'), int(os.environ.get("DGL_NUM_CLIENT")),
graph_format=formats, os.environ.get("DGL_CONF_PATH"),
keep_alive=keep_alive, graph_format=formats,
net_type=net_type) keep_alive=keep_alive,
net_type=net_type,
)
serv.start() serv.start()
sys.exit() sys.exit()
else: else:
num_workers = int(os.environ.get('DGL_NUM_SAMPLER', 0)) num_workers = int(os.environ.get("DGL_NUM_SAMPLER", 0))
num_servers = int(os.environ.get('DGL_NUM_SERVER', 1)) num_servers = int(os.environ.get("DGL_NUM_SERVER", 1))
group_id = int(os.environ.get('DGL_GROUP_ID', 0)) group_id = int(os.environ.get("DGL_GROUP_ID", 0))
rpc.reset() rpc.reset()
global SAMPLER_POOL global SAMPLER_POOL
global NUM_SAMPLER_WORKERS global NUM_SAMPLER_WORKERS
is_standalone = os.environ.get( is_standalone = (
'DGL_DIST_MODE', 'standalone') == 'standalone' os.environ.get("DGL_DIST_MODE", "standalone") == "standalone"
)
if num_workers > 0 and not is_standalone: if num_workers > 0 and not is_standalone:
SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size, SAMPLER_POOL = CustomPool(
net_type, 'sampler', num_worker_threads, num_workers,
group_id)) (
ip_config,
num_servers,
max_queue_size,
net_type,
"sampler",
num_worker_threads,
group_id,
),
)
else: else:
SAMPLER_POOL = None SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers NUM_SAMPLER_WORKERS = num_workers
if not is_standalone: if not is_standalone:
assert num_servers is not None and num_servers > 0, \ assert (
'The number of servers per machine must be specified with a positive number.' num_servers is not None and num_servers > 0
connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id=group_id) ), "The number of servers per machine must be specified with a positive number."
init_role('default') connect_to_server(
init_kvstore(ip_config, num_servers, 'default') ip_config,
num_servers,
max_queue_size,
net_type,
group_id=group_id,
)
init_role("default")
init_kvstore(ip_config, num_servers, "default")
def finalize_client(): def finalize_client():
"""Release resources of this client.""" """Release resources of this client."""
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone': if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
rpc.finalize_sender() rpc.finalize_sender()
rpc.finalize_receiver() rpc.finalize_receiver()
...@@ -268,7 +321,7 @@ def _exit(): ...@@ -268,7 +321,7 @@ def _exit():
def finalize_worker(): def finalize_worker():
"""Finalize workers """Finalize workers
Python's multiprocessing pool will not call atexit function when close Python's multiprocessing pool will not call atexit function when close
""" """
global SAMPLER_POOL global SAMPLER_POOL
if SAMPLER_POOL is not None: if SAMPLER_POOL is not None:
...@@ -284,8 +337,7 @@ def join_finalize_worker(): ...@@ -284,8 +337,7 @@ def join_finalize_worker():
def is_initialized(): def is_initialized():
"""Is RPC initialized? """Is RPC initialized?"""
"""
return INITIALIZED return INITIALIZED
...@@ -297,6 +349,7 @@ def _shutdown_servers(): ...@@ -297,6 +349,7 @@ def _shutdown_servers():
for server_id in range(rpc.get_num_server()): for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req) rpc.send_request(server_id, req)
def exit_client(): def exit_client():
"""Trainer exits """Trainer exits
...@@ -308,12 +361,15 @@ def exit_client(): ...@@ -308,12 +361,15 @@ def exit_client():
needs to call `exit_client` before calling `initialize` again. needs to call `exit_client` before calling `initialize` again.
""" """
# Only client with rank_0 will send shutdown request to servers. # Only client with rank_0 will send shutdown request to servers.
print("Client[{}] in group[{}] is exiting...".format( print(
rpc.get_rank(), rpc.get_group_id())) "Client[{}] in group[{}] is exiting...".format(
rpc.get_rank(), rpc.get_group_id()
)
)
finalize_worker() # finalize workers should be earilier than barrier, and non-blocking finalize_worker() # finalize workers should be earilier than barrier, and non-blocking
# collect data such as DistTensor before exit # collect data such as DistTensor before exit
gc.collect() gc.collect()
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone': if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
rpc.client_barrier() rpc.client_barrier()
_shutdown_servers() _shutdown_servers()
finalize_client() finalize_client()
......
# pylint: disable=global-variable-undefined, invalid-name # pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training""" """Multiprocess dataloader for distributed training"""
from .dist_context import get_sampler_pool
from .. import backend as F from .. import backend as F
from .dist_context import get_sampler_pool
__all__ = ["DistDataLoader"] __all__ = ["DistDataLoader"]
...@@ -61,8 +61,15 @@ class DistDataLoader: ...@@ -61,8 +61,15 @@ class DistDataLoader:
and [3, 4] is not guaranteed. and [3, 4] is not guaranteed.
""" """
def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_last=False, def __init__(
queue_size=None): self,
dataset,
batch_size,
shuffle=False,
collate_fn=None,
drop_last=False,
queue_size=None,
):
self.pool, self.num_workers = get_sampler_pool() self.pool, self.num_workers = get_sampler_pool()
if queue_size is None: if queue_size is None:
queue_size = self.num_workers * 4 if self.num_workers > 0 else 4 queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
...@@ -71,7 +78,7 @@ class DistDataLoader: ...@@ -71,7 +78,7 @@ class DistDataLoader:
self.num_pending = 0 self.num_pending = 0
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.current_pos = 0 self.current_pos = 0
self.queue = [] # Only used when pool is None self.queue = [] # Only used when pool is None
self.drop_last = drop_last self.drop_last = drop_last
self.recv_idxs = 0 self.recv_idxs = 0
self.shuffle = shuffle self.shuffle = shuffle
...@@ -153,7 +160,7 @@ class DistDataLoader: ...@@ -153,7 +160,7 @@ class DistDataLoader:
end_pos = len(self.dataset) end_pos = len(self.dataset)
else: else:
end_pos = self.current_pos + self.batch_size end_pos = self.current_pos + self.batch_size
idx = self.data_idx[self.current_pos:end_pos].tolist() idx = self.data_idx[self.current_pos : end_pos].tolist()
ret = [self.dataset[i] for i in idx] ret = [self.dataset[i] for i in idx]
# Sharing large number of tensors between processes will consume too many # Sharing large number of tensors between processes will consume too many
# file descriptors, so let's convert each tensor to scalar value beforehand. # file descriptors, so let's convert each tensor to scalar value beforehand.
......
...@@ -3,17 +3,24 @@ ...@@ -3,17 +3,24 @@
import pickle import pickle
from abc import ABC from abc import ABC
from ast import literal_eval from ast import literal_eval
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F
from ..base import NID, EID, DGLError, dgl_warning
from .. import utils from .. import utils
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from .._ffi.ndarray import empty_shared_mem from .._ffi.ndarray import empty_shared_mem
from ..base import EID, NID, DGLError, dgl_warning
from ..ndarray import exist_shared_mem_array from ..ndarray import exist_shared_mem_array
from ..partition import NDArrayPartition from ..partition import NDArrayPartition
from .constants import DEFAULT_ETYPE, DEFAULT_NTYPE
from .id_map import IdMap from .id_map import IdMap
from .constants import DEFAULT_NTYPE, DEFAULT_ETYPE from .shared_mem_utils import (
DTYPE_DICT,
_get_edata_path,
_get_ndata_path,
_to_shared_mem,
)
def _str_to_tuple(s): def _str_to_tuple(s):
try: try:
...@@ -22,9 +29,18 @@ def _str_to_tuple(s): ...@@ -22,9 +29,18 @@ def _str_to_tuple(s):
ret = s ret = s
return ret return ret
def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id,
num_partitions, node_map, edge_map, is_range_part): def _move_metadata_to_shared_mem(
''' Move all metadata of the partition book to the shared memory. graph_name,
num_nodes,
num_edges,
part_id,
num_partitions,
node_map,
edge_map,
is_range_part,
):
"""Move all metadata of the partition book to the shared memory.
These metadata will be used to construct graph partition book. These metadata will be used to construct graph partition book.
...@@ -56,17 +72,28 @@ def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id, ...@@ -56,17 +72,28 @@ def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id,
The first tensor stores the serialized metadata, the second tensor stores the serialized The first tensor stores the serialized metadata, the second tensor stores the serialized
node map and the third tensor stores the serialized edge map. All tensors are stored in node map and the third tensor stores the serialized edge map. All tensors are stored in
shared memory. shared memory.
''' """
meta = _to_shared_mem(F.tensor([int(is_range_part), num_nodes, num_edges, meta = _to_shared_mem(
num_partitions, part_id, F.tensor(
len(node_map), len(edge_map)]), [
_get_ndata_path(graph_name, 'meta')) int(is_range_part),
node_map = _to_shared_mem(node_map, _get_ndata_path(graph_name, 'node_map')) num_nodes,
edge_map = _to_shared_mem(edge_map, _get_edata_path(graph_name, 'edge_map')) num_edges,
num_partitions,
part_id,
len(node_map),
len(edge_map),
]
),
_get_ndata_path(graph_name, "meta"),
)
node_map = _to_shared_mem(node_map, _get_ndata_path(graph_name, "node_map"))
edge_map = _to_shared_mem(edge_map, _get_edata_path(graph_name, "edge_map"))
return meta, node_map, edge_map return meta, node_map, edge_map
def _get_shared_mem_metadata(graph_name): def _get_shared_mem_metadata(graph_name):
''' Get the metadata of the graph from shared memory. """Get the metadata of the graph from shared memory.
The server serializes the metadata of a graph and store them in shared memory. The server serializes the metadata of a graph and store them in shared memory.
The client needs to deserialize the data in shared memory and get the metadata The client needs to deserialize the data in shared memory and get the metadata
...@@ -85,24 +112,38 @@ def _get_shared_mem_metadata(graph_name): ...@@ -85,24 +112,38 @@ def _get_shared_mem_metadata(graph_name):
the third element is the number of partitions; the third element is the number of partitions;
the fourth element is the tensor that stores the serialized result of node maps; the fourth element is the tensor that stores the serialized result of node maps;
the fifth element is the tensor that stores the serialized result of edge maps. the fifth element is the tensor that stores the serialized result of edge maps.
''' """
# The metadata has 7 elements: is_range_part, num_nodes, num_edges, num_partitions, part_id, # The metadata has 7 elements: is_range_part, num_nodes, num_edges, num_partitions, part_id,
# the length of node map and the length of the edge map. # the length of node map and the length of the edge map.
shape = (7,) shape = (7,)
dtype = F.int64 dtype = F.int64
dtype = DTYPE_DICT[dtype] dtype = DTYPE_DICT[dtype]
data = empty_shared_mem(_get_ndata_path(graph_name, 'meta'), False, shape, dtype) data = empty_shared_mem(
_get_ndata_path(graph_name, "meta"), False, shape, dtype
)
dlpack = data.to_dlpack() dlpack = data.to_dlpack()
meta = F.asnumpy(F.zerocopy_from_dlpack(dlpack)) meta = F.asnumpy(F.zerocopy_from_dlpack(dlpack))
is_range_part, _, _, num_partitions, part_id, node_map_len, edge_map_len = meta (
is_range_part,
_,
_,
num_partitions,
part_id,
node_map_len,
edge_map_len,
) = meta
# Load node map # Load node map
data = empty_shared_mem(_get_ndata_path(graph_name, 'node_map'), False, (node_map_len,), dtype) data = empty_shared_mem(
_get_ndata_path(graph_name, "node_map"), False, (node_map_len,), dtype
)
dlpack = data.to_dlpack() dlpack = data.to_dlpack()
node_map = F.zerocopy_from_dlpack(dlpack) node_map = F.zerocopy_from_dlpack(dlpack)
# Load edge_map # Load edge_map
data = empty_shared_mem(_get_edata_path(graph_name, 'edge_map'), False, (edge_map_len,), dtype) data = empty_shared_mem(
_get_edata_path(graph_name, "edge_map"), False, (edge_map_len,), dtype
)
dlpack = data.to_dlpack() dlpack = data.to_dlpack()
edge_map = F.zerocopy_from_dlpack(dlpack) edge_map = F.zerocopy_from_dlpack(dlpack)
...@@ -110,7 +151,7 @@ def _get_shared_mem_metadata(graph_name): ...@@ -110,7 +151,7 @@ def _get_shared_mem_metadata(graph_name):
def get_shared_mem_partition_book(graph_name, graph_part): def get_shared_mem_partition_book(graph_name, graph_part):
'''Get a graph partition book from shared memory. """Get a graph partition book from shared memory.
A graph partition book of a specific graph can be serialized to shared memory. A graph partition book of a specific graph can be serialized to shared memory.
We can reconstruct a graph partition book from shared memory. We can reconstruct a graph partition book from shared memory.
...@@ -126,11 +167,16 @@ def get_shared_mem_partition_book(graph_name, graph_part): ...@@ -126,11 +167,16 @@ def get_shared_mem_partition_book(graph_name, graph_part):
------- -------
GraphPartitionBook GraphPartitionBook
A graph partition book for a particular partition. A graph partition book for a particular partition.
''' """
if not exist_shared_mem_array(_get_ndata_path(graph_name, 'meta')): if not exist_shared_mem_array(_get_ndata_path(graph_name, "meta")):
return None return None
is_range_part, part_id, num_parts, node_map_data, edge_map_data = \ (
_get_shared_mem_metadata(graph_name) is_range_part,
part_id,
num_parts,
node_map_data,
edge_map_data,
) = _get_shared_mem_metadata(graph_name)
if is_range_part == 1: if is_range_part == 1:
# node ID ranges and edge ID ranges are stored in the order of node type IDs # node ID ranges and edge ID ranges are stored in the order of node type IDs
# and edge type IDs. # and edge type IDs.
...@@ -150,12 +196,17 @@ def get_shared_mem_partition_book(graph_name, graph_part): ...@@ -150,12 +196,17 @@ def get_shared_mem_partition_book(graph_name, graph_part):
for i, (etype, eid_range) in enumerate(edge_map_data): for i, (etype, eid_range) in enumerate(edge_map_data):
etypes[etype] = i etypes[etype] = i
edge_map[etype] = eid_range edge_map[etype] = eid_range
return RangePartitionBook(part_id, num_parts, node_map, edge_map, ntypes, etypes) return RangePartitionBook(
part_id, num_parts, node_map, edge_map, ntypes, etypes
)
else: else:
return BasicPartitionBook(part_id, num_parts, node_map_data, edge_map_data, graph_part) return BasicPartitionBook(
part_id, num_parts, node_map_data, edge_map_data, graph_part
)
def get_node_partition_from_book(book, device): def get_node_partition_from_book(book, device):
""" Get an NDArrayPartition of the nodes from a RangePartitionBook. """Get an NDArrayPartition of the nodes from a RangePartitionBook.
Parameters Parameters
---------- ----------
...@@ -169,25 +220,27 @@ def get_node_partition_from_book(book, device): ...@@ -169,25 +220,27 @@ def get_node_partition_from_book(book, device):
NDarrayPartition NDarrayPartition
The NDArrayPartition object for the nodes in the graph. The NDArrayPartition object for the nodes in the graph.
""" """
assert isinstance(book, RangePartitionBook), "Can only convert " \ assert isinstance(book, RangePartitionBook), (
"RangePartitionBook to NDArrayPartition." "Can only convert " "RangePartitionBook to NDArrayPartition."
)
# create prefix-sum array on host # create prefix-sum array on host
max_node_ids = F.zerocopy_from_numpy(book._max_node_ids) max_node_ids = F.zerocopy_from_numpy(book._max_node_ids)
cpu_range = F.cat([F.tensor([0], dtype=F.dtype(max_node_ids)), cpu_range = F.cat(
max_node_ids+1], dim=0) [F.tensor([0], dtype=F.dtype(max_node_ids)), max_node_ids + 1], dim=0
)
gpu_range = F.copy_to(cpu_range, ctx=device) gpu_range = F.copy_to(cpu_range, ctx=device)
# convert from numpy # convert from numpy
array_size = int(F.as_scalar(cpu_range[-1])) array_size = int(F.as_scalar(cpu_range[-1]))
num_parts = book.num_partitions() num_parts = book.num_partitions()
return NDArrayPartition(array_size, return NDArrayPartition(
num_parts, array_size, num_parts, mode="range", part_ranges=gpu_range
mode='range', )
part_ranges=gpu_range)
class GraphPartitionBook(ABC): class GraphPartitionBook(ABC):
""" The base class of the graph partition book. """The base class of the graph partition book.
For distributed training, a graph is partitioned into multiple parts and is loaded For distributed training, a graph is partitioned into multiple parts and is loaded
in multiple machines. The partition book contains all necessary information to locate in multiple machines. The partition book contains all necessary information to locate
...@@ -368,13 +421,11 @@ class GraphPartitionBook(ABC): ...@@ -368,13 +421,11 @@ class GraphPartitionBook(ABC):
@property @property
def ntypes(self): def ntypes(self):
"""Get the list of node types """Get the list of node types"""
"""
@property @property
def etypes(self): def etypes(self):
"""Get the list of edge types """Get the list of edge types"""
"""
@property @property
def canonical_etypes(self): def canonical_etypes(self):
...@@ -388,9 +439,8 @@ class GraphPartitionBook(ABC): ...@@ -388,9 +439,8 @@ class GraphPartitionBook(ABC):
@property @property
def is_homogeneous(self): def is_homogeneous(self):
"""check if homogeneous """check if homogeneous"""
""" return not (len(self.etypes) > 1 or len(self.ntypes) > 1)
return not(len(self.etypes) > 1 or len(self.ntypes) > 1)
def map_to_per_ntype(self, ids): def map_to_per_ntype(self, ids):
"""Map homogeneous node IDs to type-wise IDs and node types. """Map homogeneous node IDs to type-wise IDs and node types.
...@@ -452,6 +502,7 @@ class GraphPartitionBook(ABC): ...@@ -452,6 +502,7 @@ class GraphPartitionBook(ABC):
Homogeneous edge IDs. Homogeneous edge IDs.
""" """
class BasicPartitionBook(GraphPartitionBook): class BasicPartitionBook(GraphPartitionBook):
"""This provides the most flexible way to store parition information. """This provides the most flexible way to store parition information.
...@@ -473,33 +524,40 @@ class BasicPartitionBook(GraphPartitionBook): ...@@ -473,33 +524,40 @@ class BasicPartitionBook(GraphPartitionBook):
part_graph : DGLGraph part_graph : DGLGraph
The graph partition structure. The graph partition structure.
""" """
def __init__(self, part_id, num_parts, node_map, edge_map, part_graph): def __init__(self, part_id, num_parts, node_map, edge_map, part_graph):
assert part_id >= 0, 'part_id cannot be a negative number.' assert part_id >= 0, "part_id cannot be a negative number."
assert num_parts > 0, 'num_parts must be greater than zero.' assert num_parts > 0, "num_parts must be greater than zero."
self._part_id = int(part_id) self._part_id = int(part_id)
self._num_partitions = int(num_parts) self._num_partitions = int(num_parts)
self._nid2partid = F.tensor(node_map) self._nid2partid = F.tensor(node_map)
assert F.dtype(self._nid2partid) == F.int64, \ assert (
'the node map must be stored in an integer array' F.dtype(self._nid2partid) == F.int64
), "the node map must be stored in an integer array"
self._eid2partid = F.tensor(edge_map) self._eid2partid = F.tensor(edge_map)
assert F.dtype(self._eid2partid) == F.int64, \ assert (
'the edge map must be stored in an integer array' F.dtype(self._eid2partid) == F.int64
), "the edge map must be stored in an integer array"
# Get meta data of the partition book. # Get meta data of the partition book.
self._partition_meta_data = [] self._partition_meta_data = []
_, nid_count = np.unique(F.asnumpy(self._nid2partid), return_counts=True) _, nid_count = np.unique(
_, eid_count = np.unique(F.asnumpy(self._eid2partid), return_counts=True) F.asnumpy(self._nid2partid), return_counts=True
)
_, eid_count = np.unique(
F.asnumpy(self._eid2partid), return_counts=True
)
for partid in range(self._num_partitions): for partid in range(self._num_partitions):
part_info = {} part_info = {}
part_info['machine_id'] = partid part_info["machine_id"] = partid
part_info['num_nodes'] = int(nid_count[partid]) part_info["num_nodes"] = int(nid_count[partid])
part_info['num_edges'] = int(eid_count[partid]) part_info["num_edges"] = int(eid_count[partid])
self._partition_meta_data.append(part_info) self._partition_meta_data.append(part_info)
# Get partid2nids # Get partid2nids
self._partid2nids = [] self._partid2nids = []
sorted_nid = F.tensor(np.argsort(F.asnumpy(self._nid2partid))) sorted_nid = F.tensor(np.argsort(F.asnumpy(self._nid2partid)))
start = 0 start = 0
for offset in nid_count: for offset in nid_count:
part_nids = sorted_nid[start:start+offset] part_nids = sorted_nid[start : start + offset]
start += offset start += offset
self._partid2nids.append(part_nids) self._partid2nids.append(part_nids)
# Get partid2eids # Get partid2eids
...@@ -507,7 +565,7 @@ class BasicPartitionBook(GraphPartitionBook): ...@@ -507,7 +565,7 @@ class BasicPartitionBook(GraphPartitionBook):
sorted_eid = F.tensor(np.argsort(F.asnumpy(self._eid2partid))) sorted_eid = F.tensor(np.argsort(F.asnumpy(self._eid2partid)))
start = 0 start = 0
for offset in eid_count: for offset in eid_count:
part_eids = sorted_eid[start:start+offset] part_eids = sorted_eid[start : start + offset]
start += offset start += offset
self._partid2eids.append(part_eids) self._partid2eids.append(part_eids)
# Get nidg2l # Get nidg2l
...@@ -515,7 +573,7 @@ class BasicPartitionBook(GraphPartitionBook): ...@@ -515,7 +573,7 @@ class BasicPartitionBook(GraphPartitionBook):
global_id = part_graph.ndata[NID] global_id = part_graph.ndata[NID]
max_global_id = np.amax(F.asnumpy(global_id)) max_global_id = np.amax(F.asnumpy(global_id))
# TODO(chao): support int32 index # TODO(chao): support int32 index
g2l = F.zeros((max_global_id+1), F.int64, F.context(global_id)) g2l = F.zeros((max_global_id + 1), F.int64, F.context(global_id))
g2l = F.scatter_row(g2l, global_id, F.arange(0, len(global_id))) g2l = F.scatter_row(g2l, global_id, F.arange(0, len(global_id)))
self._nidg2l[self._part_id] = g2l self._nidg2l[self._part_id] = g2l
# Get eidg2l # Get eidg2l
...@@ -523,7 +581,7 @@ class BasicPartitionBook(GraphPartitionBook): ...@@ -523,7 +581,7 @@ class BasicPartitionBook(GraphPartitionBook):
global_id = part_graph.edata[EID] global_id = part_graph.edata[EID]
max_global_id = np.amax(F.asnumpy(global_id)) max_global_id = np.amax(F.asnumpy(global_id))
# TODO(chao): support int32 index # TODO(chao): support int32 index
g2l = F.zeros((max_global_id+1), F.int64, F.context(global_id)) g2l = F.zeros((max_global_id + 1), F.int64, F.context(global_id))
g2l = F.scatter_row(g2l, global_id, F.arange(0, len(global_id))) g2l = F.scatter_row(g2l, global_id, F.arange(0, len(global_id)))
self._eidg2l[self._part_id] = g2l self._eidg2l[self._part_id] = g2l
# node size and edge size # node size and edge size
...@@ -531,33 +589,43 @@ class BasicPartitionBook(GraphPartitionBook): ...@@ -531,33 +589,43 @@ class BasicPartitionBook(GraphPartitionBook):
self._node_size = len(self.partid2nids(self._part_id)) self._node_size = len(self.partid2nids(self._part_id))
def shared_memory(self, graph_name): def shared_memory(self, graph_name):
"""Move data to shared memory. """Move data to shared memory."""
""" (
self._meta, self._nid2partid, self._eid2partid = _move_metadata_to_shared_mem( self._meta,
graph_name, self._num_nodes(), self._num_edges(), self._part_id, self._num_partitions, self._nid2partid,
self._nid2partid, self._eid2partid, False) self._eid2partid,
) = _move_metadata_to_shared_mem(
graph_name,
self._num_nodes(),
self._num_edges(),
self._part_id,
self._num_partitions,
self._nid2partid,
self._eid2partid,
False,
)
def num_partitions(self): def num_partitions(self):
"""Return the number of partitions. """Return the number of partitions."""
"""
return self._num_partitions return self._num_partitions
def metadata(self): def metadata(self):
"""Return the partition meta data. """Return the partition meta data."""
"""
return self._partition_meta_data return self._partition_meta_data
def _num_nodes(self, ntype=DEFAULT_NTYPE): def _num_nodes(self, ntype=DEFAULT_NTYPE):
""" The total number of nodes """The total number of nodes"""
""" assert (
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.' ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return len(self._nid2partid) return len(self._nid2partid)
def _num_edges(self, etype=DEFAULT_ETYPE): def _num_edges(self, etype=DEFAULT_ETYPE):
""" The total number of edges """The total number of edges"""
""" assert etype in (
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \ DEFAULT_ETYPE,
'Base partition book only supports homogeneous graph.' DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return len(self._eid2partid) return len(self._eid2partid)
def map_to_per_ntype(self, ids): def map_to_per_ntype(self, ids):
...@@ -575,79 +643,88 @@ class BasicPartitionBook(GraphPartitionBook): ...@@ -575,79 +643,88 @@ class BasicPartitionBook(GraphPartitionBook):
return F.zeros((len(ids),), F.int32, F.cpu()), ids return F.zeros((len(ids),), F.int32, F.cpu()), ids
def map_to_homo_nid(self, ids, ntype=DEFAULT_NTYPE): def map_to_homo_nid(self, ids, ntype=DEFAULT_NTYPE):
"""Map per-node-type IDs to global node IDs in the homogeneous format. """Map per-node-type IDs to global node IDs in the homogeneous format."""
""" assert (
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.' ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return ids return ids
def map_to_homo_eid(self, ids, etype=DEFAULT_ETYPE): def map_to_homo_eid(self, ids, etype=DEFAULT_ETYPE):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format. """Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
""" assert etype in (
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \ DEFAULT_ETYPE,
'Base partition book only supports homogeneous graph.' DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return ids return ids
def nid2partid(self, nids, ntype=DEFAULT_NTYPE): def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
"""From global node IDs to partition IDs """From global node IDs to partition IDs"""
""" assert (
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.' ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return F.gather_row(self._nid2partid, nids) return F.gather_row(self._nid2partid, nids)
def eid2partid(self, eids, etype=DEFAULT_ETYPE): def eid2partid(self, eids, etype=DEFAULT_ETYPE):
"""From global edge IDs to partition IDs """From global edge IDs to partition IDs"""
""" assert etype in (
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \ DEFAULT_ETYPE,
'Base partition book only supports homogeneous graph.' DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return F.gather_row(self._eid2partid, eids) return F.gather_row(self._eid2partid, eids)
def partid2nids(self, partid, ntype=DEFAULT_NTYPE): def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition id to global node IDs """From partition id to global node IDs"""
""" assert (
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.' ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
return self._partid2nids[partid] return self._partid2nids[partid]
def partid2eids(self, partid, etype=DEFAULT_ETYPE): def partid2eids(self, partid, etype=DEFAULT_ETYPE):
"""From partition id to global edge IDs """From partition id to global edge IDs"""
""" assert etype in (
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \ DEFAULT_ETYPE,
'Base partition book only supports homogeneous graph.' DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return self._partid2eids[partid] return self._partid2eids[partid]
def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE): def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE):
"""Get local node IDs within the given partition. """Get local node IDs within the given partition."""
""" assert (
assert ntype == DEFAULT_NTYPE, 'Base partition book only supports homogeneous graph.' ntype == DEFAULT_NTYPE
), "Base partition book only supports homogeneous graph."
if partid != self._part_id: if partid != self._part_id:
raise RuntimeError('Now GraphPartitionBook does not support \ raise RuntimeError(
getting remote tensor of nid2localnid.') "Now GraphPartitionBook does not support \
getting remote tensor of nid2localnid."
)
return F.gather_row(self._nidg2l[partid], nids) return F.gather_row(self._nidg2l[partid], nids)
def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE): def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE):
"""Get the local edge ids within the given partition. """Get the local edge ids within the given partition."""
""" assert etype in (
assert etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]), \ DEFAULT_ETYPE,
'Base partition book only supports homogeneous graph.' DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
if partid != self._part_id: if partid != self._part_id:
raise RuntimeError('Now GraphPartitionBook does not support \ raise RuntimeError(
getting remote tensor of eid2localeid.') "Now GraphPartitionBook does not support \
getting remote tensor of eid2localeid."
)
return F.gather_row(self._eidg2l[partid], eids) return F.gather_row(self._eidg2l[partid], eids)
@property @property
def partid(self): def partid(self):
"""Get the current partition ID """Get the current partition ID"""
"""
return self._part_id return self._part_id
@property @property
def ntypes(self): def ntypes(self):
"""Get the list of node types """Get the list of node types"""
"""
return [DEFAULT_NTYPE] return [DEFAULT_NTYPE]
@property @property
def etypes(self): def etypes(self):
"""Get the list of edge types """Get the list of edge types"""
"""
return [DEFAULT_ETYPE[1]] return [DEFAULT_ETYPE[1]]
@property @property
...@@ -698,9 +775,10 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -698,9 +775,10 @@ class RangePartitionBook(GraphPartitionBook):
Single string format for keys of ``edge_map`` and ``etypes`` is deprecated. Single string format for keys of ``edge_map`` and ``etypes`` is deprecated.
``(str, str, str)`` will be the only format supported in the future. ``(str, str, str)`` will be the only format supported in the future.
""" """
def __init__(self, part_id, num_parts, node_map, edge_map, ntypes, etypes): def __init__(self, part_id, num_parts, node_map, edge_map, ntypes, etypes):
assert part_id >= 0, 'part_id cannot be a negative number.' assert part_id >= 0, "part_id cannot be a negative number."
assert num_parts > 0, 'num_parts must be greater than zero.' assert num_parts > 0, "num_parts must be greater than zero."
self._partid = part_id self._partid = part_id
self._num_partitions = num_parts self._num_partitions = num_parts
self._ntypes = [None] * len(ntypes) self._ntypes = [None] * len(ntypes)
...@@ -711,12 +789,14 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -711,12 +789,14 @@ class RangePartitionBook(GraphPartitionBook):
for ntype in ntypes: for ntype in ntypes:
ntype_id = ntypes[ntype] ntype_id = ntypes[ntype]
self._ntypes[ntype_id] = ntype self._ntypes[ntype_id] = ntype
assert all(ntype is not None for ntype in self._ntypes), \ assert all(
"The node types have invalid IDs." ntype is not None for ntype in self._ntypes
), "The node types have invalid IDs."
for etype, etype_id in etypes.items(): for etype, etype_id in etypes.items():
if isinstance(etype, tuple): if isinstance(etype, tuple):
assert len(etype) == 3, \ assert (
'Canonical etype should be in format of (str, str, str).' len(etype) == 3
), "Canonical etype should be in format of (str, str, str)."
c_etype = etype c_etype = etype
etype = etype[1] etype = etype[1]
self._etypes[etype_id] = etype self._etypes[etype_id] = etype
...@@ -730,11 +810,13 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -730,11 +810,13 @@ class RangePartitionBook(GraphPartitionBook):
self._etype2canonical[etype] = c_etype self._etype2canonical[etype] = c_etype
else: else:
dgl_warning( dgl_warning(
"Etype with 'str' format is deprecated. Please use '(str, str, str)'.") "Etype with 'str' format is deprecated. Please use '(str, str, str)'."
)
self._etypes[etype_id] = etype self._etypes[etype_id] = etype
self._canonical_etypes[etype_id] = None self._canonical_etypes[etype_id] = None
assert all(etype is not None for etype in self._etypes), \ assert all(
"The edge types have invalid IDs." etype is not None for etype in self._etypes
), "The edge types have invalid IDs."
# This stores the node ID ranges for each node type in each partition. # This stores the node ID ranges for each node type in each partition.
# The key is the node type, the value is a NumPy matrix with two columns, in which # The key is the node type, the value is a NumPy matrix with two columns, in which
...@@ -747,16 +829,20 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -747,16 +829,20 @@ class RangePartitionBook(GraphPartitionBook):
self._typed_max_node_ids = {} self._typed_max_node_ids = {}
max_node_map = np.zeros((num_parts,), dtype=np.int64) max_node_map = np.zeros((num_parts,), dtype=np.int64)
for key in node_map: for key in node_map:
assert key in ntypes, 'Unexpected ntype: {}.'.format(key) assert key in ntypes, "Unexpected ntype: {}.".format(key)
if not isinstance(node_map[key], np.ndarray): if not isinstance(node_map[key], np.ndarray):
node_map[key] = F.asnumpy(node_map[key]) node_map[key] = F.asnumpy(node_map[key])
assert node_map[key].shape == (num_parts, 2) assert node_map[key].shape == (num_parts, 2)
self._typed_nid_range[key] = node_map[key] self._typed_nid_range[key] = node_map[key]
# This is used for per-node-type lookup. # This is used for per-node-type lookup.
self._typed_max_node_ids[key] = np.cumsum(self._typed_nid_range[key][:, 1] self._typed_max_node_ids[key] = np.cumsum(
- self._typed_nid_range[key][:, 0]) self._typed_nid_range[key][:, 1]
- self._typed_nid_range[key][:, 0]
)
# This is used for homogeneous node ID lookup. # This is used for homogeneous node ID lookup.
max_node_map = np.maximum(self._typed_nid_range[key][:, 1], max_node_map) max_node_map = np.maximum(
self._typed_nid_range[key][:, 1], max_node_map
)
# This is a vector that indicates the last node ID in each partition. # This is a vector that indicates the last node ID in each partition.
# The ID is the global ID in the homogeneous representation. # The ID is the global ID in the homogeneous representation.
self._max_node_ids = max_node_map self._max_node_ids = max_node_map
...@@ -767,16 +853,20 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -767,16 +853,20 @@ class RangePartitionBook(GraphPartitionBook):
self._typed_max_edge_ids = {} self._typed_max_edge_ids = {}
max_edge_map = np.zeros((num_parts,), dtype=np.int64) max_edge_map = np.zeros((num_parts,), dtype=np.int64)
for key in edge_map: for key in edge_map:
assert key in etypes, 'Unexpected etype: {}.'.format(key) assert key in etypes, "Unexpected etype: {}.".format(key)
if not isinstance(edge_map[key], np.ndarray): if not isinstance(edge_map[key], np.ndarray):
edge_map[key] = F.asnumpy(edge_map[key]) edge_map[key] = F.asnumpy(edge_map[key])
assert edge_map[key].shape == (num_parts, 2) assert edge_map[key].shape == (num_parts, 2)
self._typed_eid_range[key] = edge_map[key] self._typed_eid_range[key] = edge_map[key]
# This is used for per-edge-type lookup. # This is used for per-edge-type lookup.
self._typed_max_edge_ids[key] = np.cumsum(self._typed_eid_range[key][:, 1] self._typed_max_edge_ids[key] = np.cumsum(
- self._typed_eid_range[key][:, 0]) self._typed_eid_range[key][:, 1]
- self._typed_eid_range[key][:, 0]
)
# This is used for homogeneous edge ID lookup. # This is used for homogeneous edge ID lookup.
max_edge_map = np.maximum(self._typed_eid_range[key][:, 1], max_edge_map) max_edge_map = np.maximum(
self._typed_eid_range[key][:, 1], max_edge_map
)
# Similar to _max_node_ids # Similar to _max_node_ids
self._max_edge_ids = max_edge_map self._max_edge_ids = max_edge_map
...@@ -796,14 +886,13 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -796,14 +886,13 @@ class RangePartitionBook(GraphPartitionBook):
num_edges = erange_end - erange_start num_edges = erange_end - erange_start
part_info = {} part_info = {}
part_info['machine_id'] = partid part_info["machine_id"] = partid
part_info['num_nodes'] = int(num_nodes) part_info["num_nodes"] = int(num_nodes)
part_info['num_edges'] = int(num_edges) part_info["num_edges"] = int(num_edges)
self._partition_meta_data.append(part_info) self._partition_meta_data.append(part_info)
def shared_memory(self, graph_name): def shared_memory(self, graph_name):
"""Move data to shared memory. """Move data to shared memory."""
"""
# we need to store the nid ranges and eid ranges of different types in the order defined # we need to store the nid ranges and eid ranges of different types in the order defined
# by type IDs. # by type IDs.
nid_range = [None] * len(self.ntypes) nid_range = [None] * len(self.ntypes)
...@@ -817,31 +906,30 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -817,31 +906,30 @@ class RangePartitionBook(GraphPartitionBook):
eid_range[i] = (c_etype, self._typed_eid_range[c_etype]) eid_range[i] = (c_etype, self._typed_eid_range[c_etype])
eid_range_pickle = list(pickle.dumps(eid_range)) eid_range_pickle = list(pickle.dumps(eid_range))
self._meta = _move_metadata_to_shared_mem(graph_name, self._meta = _move_metadata_to_shared_mem(
0, # We don't need to provide the number of nodes graph_name,
0, # We don't need to provide the number of edges 0, # We don't need to provide the number of nodes
self._partid, self._num_partitions, 0, # We don't need to provide the number of edges
F.tensor(nid_range_pickle), self._partid,
F.tensor(eid_range_pickle), self._num_partitions,
True) F.tensor(nid_range_pickle),
F.tensor(eid_range_pickle),
True,
)
def num_partitions(self): def num_partitions(self):
"""Return the number of partitions. """Return the number of partitions."""
"""
return self._num_partitions return self._num_partitions
def _num_nodes(self, ntype=DEFAULT_NTYPE): def _num_nodes(self, ntype=DEFAULT_NTYPE):
""" The total number of nodes """The total number of nodes"""
"""
if ntype == DEFAULT_NTYPE: if ntype == DEFAULT_NTYPE:
return int(self._max_node_ids[-1]) return int(self._max_node_ids[-1])
else: else:
return int(self._typed_max_node_ids[ntype][-1]) return int(self._typed_max_node_ids[ntype][-1])
def _num_edges(self, etype=DEFAULT_ETYPE): def _num_edges(self, etype=DEFAULT_ETYPE):
""" The total number of edges """The total number of edges"""
"""
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]): if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
return int(self._max_edge_ids[-1]) return int(self._max_edge_ids[-1])
else: else:
...@@ -849,8 +937,7 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -849,8 +937,7 @@ class RangePartitionBook(GraphPartitionBook):
return int(self._typed_max_edge_ids[c_etype][-1]) return int(self._typed_max_edge_ids[c_etype][-1])
def metadata(self): def metadata(self):
"""Return the partition meta data. """Return the partition meta data."""
"""
return self._partition_meta_data return self._partition_meta_data
def map_to_per_ntype(self, ids): def map_to_per_ntype(self, ids):
...@@ -868,67 +955,75 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -868,67 +955,75 @@ class RangePartitionBook(GraphPartitionBook):
return self._eid_map(ids) return self._eid_map(ids)
def map_to_homo_nid(self, ids, ntype): def map_to_homo_nid(self, ids, ntype):
"""Map per-node-type IDs to global node IDs in the homogeneous format. """Map per-node-type IDs to global node IDs in the homogeneous format."""
"""
ids = utils.toindex(ids).tousertensor() ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype) partids = self.nid2partid(ids, ntype)
typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype]) typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
end_diff = F.gather_row(typed_max_nids, partids) - ids end_diff = F.gather_row(typed_max_nids, partids) - ids
typed_nid_range = F.zerocopy_from_numpy(self._typed_nid_range[ntype][:, 1]) typed_nid_range = F.zerocopy_from_numpy(
self._typed_nid_range[ntype][:, 1]
)
return F.gather_row(typed_nid_range, partids) - end_diff return F.gather_row(typed_nid_range, partids) - end_diff
def map_to_homo_eid(self, ids, etype): def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format. """Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
"""
ids = utils.toindex(ids).tousertensor() ids = utils.toindex(ids).tousertensor()
c_etype = self._to_canonical_etype(etype) c_etype = self._to_canonical_etype(etype)
partids = self.eid2partid(ids, c_etype) partids = self.eid2partid(ids, c_etype)
typed_max_eids = F.zerocopy_from_numpy(self._typed_max_edge_ids[c_etype]) typed_max_eids = F.zerocopy_from_numpy(
self._typed_max_edge_ids[c_etype]
)
end_diff = F.gather_row(typed_max_eids, partids) - ids end_diff = F.gather_row(typed_max_eids, partids) - ids
typed_eid_range = F.zerocopy_from_numpy(self._typed_eid_range[c_etype][:, 1]) typed_eid_range = F.zerocopy_from_numpy(
self._typed_eid_range[c_etype][:, 1]
)
return F.gather_row(typed_eid_range, partids) - end_diff return F.gather_row(typed_eid_range, partids) - end_diff
def nid2partid(self, nids, ntype=DEFAULT_NTYPE): def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
"""From global node IDs to partition IDs """From global node IDs to partition IDs"""
"""
nids = utils.toindex(nids) nids = utils.toindex(nids)
if ntype == DEFAULT_NTYPE: if ntype == DEFAULT_NTYPE:
ret = np.searchsorted(self._max_node_ids, nids.tonumpy(), side='right') ret = np.searchsorted(
self._max_node_ids, nids.tonumpy(), side="right"
)
else: else:
ret = np.searchsorted(self._typed_max_node_ids[ntype], nids.tonumpy(), side='right') ret = np.searchsorted(
self._typed_max_node_ids[ntype], nids.tonumpy(), side="right"
)
ret = utils.toindex(ret) ret = utils.toindex(ret)
return ret.tousertensor() return ret.tousertensor()
def eid2partid(self, eids, etype=DEFAULT_ETYPE): def eid2partid(self, eids, etype=DEFAULT_ETYPE):
"""From global edge IDs to partition IDs """From global edge IDs to partition IDs"""
"""
eids = utils.toindex(eids) eids = utils.toindex(eids)
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]): if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
ret = np.searchsorted(self._max_edge_ids, eids.tonumpy(), side='right') ret = np.searchsorted(
self._max_edge_ids, eids.tonumpy(), side="right"
)
else: else:
c_etype = self._to_canonical_etype(etype) c_etype = self._to_canonical_etype(etype)
ret = np.searchsorted(self._typed_max_edge_ids[c_etype], eids.tonumpy(), side='right') ret = np.searchsorted(
self._typed_max_edge_ids[c_etype], eids.tonumpy(), side="right"
)
ret = utils.toindex(ret) ret = utils.toindex(ret)
return ret.tousertensor() return ret.tousertensor()
def partid2nids(self, partid, ntype=DEFAULT_NTYPE): def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition ID to global node IDs """From partition ID to global node IDs"""
"""
# TODO do we need to cache it? # TODO do we need to cache it?
if ntype == DEFAULT_NTYPE: if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0 start = self._max_node_ids[partid - 1] if partid > 0 else 0
end = self._max_node_ids[partid] end = self._max_node_ids[partid]
return F.arange(start, end) return F.arange(start, end)
else: else:
start = self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0 start = (
self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0
)
end = self._typed_max_node_ids[ntype][partid] end = self._typed_max_node_ids[ntype][partid]
return F.arange(start, end) return F.arange(start, end)
def partid2eids(self, partid, etype=DEFAULT_ETYPE): def partid2eids(self, partid, etype=DEFAULT_ETYPE):
"""From partition ID to global edge IDs """From partition ID to global edge IDs"""
"""
# TODO do we need to cache it? # TODO do we need to cache it?
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]): if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0 start = self._max_edge_ids[partid - 1] if partid > 0 else 0
...@@ -936,33 +1031,39 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -936,33 +1031,39 @@ class RangePartitionBook(GraphPartitionBook):
return F.arange(start, end) return F.arange(start, end)
else: else:
c_etype = self._to_canonical_etype(etype) c_etype = self._to_canonical_etype(etype)
start = self._typed_max_edge_ids[c_etype][partid - 1] if partid > 0 else 0 start = (
self._typed_max_edge_ids[c_etype][partid - 1]
if partid > 0
else 0
)
end = self._typed_max_edge_ids[c_etype][partid] end = self._typed_max_edge_ids[c_etype][partid]
return F.arange(start, end) return F.arange(start, end)
def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE): def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE):
"""Get local node IDs within the given partition. """Get local node IDs within the given partition."""
"""
if partid != self._partid: if partid != self._partid:
raise RuntimeError('Now RangePartitionBook does not support \ raise RuntimeError(
getting remote tensor of nid2localnid.') "Now RangePartitionBook does not support \
getting remote tensor of nid2localnid."
)
nids = utils.toindex(nids) nids = utils.toindex(nids)
nids = nids.tousertensor() nids = nids.tousertensor()
if ntype == DEFAULT_NTYPE: if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0 start = self._max_node_ids[partid - 1] if partid > 0 else 0
else: else:
start = self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0 start = (
self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0
)
return nids - int(start) return nids - int(start)
def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE): def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE):
"""Get the local edge IDs within the given partition. """Get the local edge IDs within the given partition."""
"""
if partid != self._partid: if partid != self._partid:
raise RuntimeError('Now RangePartitionBook does not support \ raise RuntimeError(
getting remote tensor of eid2localeid.') "Now RangePartitionBook does not support \
getting remote tensor of eid2localeid."
)
eids = utils.toindex(eids) eids = utils.toindex(eids)
eids = eids.tousertensor() eids = eids.tousertensor()
...@@ -970,26 +1071,26 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -970,26 +1071,26 @@ class RangePartitionBook(GraphPartitionBook):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0 start = self._max_edge_ids[partid - 1] if partid > 0 else 0
else: else:
c_etype = self._to_canonical_etype(etype) c_etype = self._to_canonical_etype(etype)
start = self._typed_max_edge_ids[c_etype][partid - 1] if partid > 0 else 0 start = (
self._typed_max_edge_ids[c_etype][partid - 1]
if partid > 0
else 0
)
return eids - int(start) return eids - int(start)
@property @property
def partid(self): def partid(self):
"""Get the current partition ID. """Get the current partition ID."""
"""
return self._partid return self._partid
@property @property
def ntypes(self): def ntypes(self):
"""Get the list of node types """Get the list of node types"""
"""
return self._ntypes return self._ntypes
@property @property
def etypes(self): def etypes(self):
"""Get the list of edge types """Get the list of edge types"""
"""
return self._etypes return self._etypes
@property @property
...@@ -1023,12 +1124,16 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -1023,12 +1124,16 @@ class RangePartitionBook(GraphPartitionBook):
if ret is None: if ret is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype)) raise DGLError('Edge type "{}" does not exist.'.format(etype))
if len(ret) == 0: if len(ret) == 0:
raise DGLError('Edge type "%s" is ambiguous. Please use canonical edge type ' raise DGLError(
'in the form of (srctype, etype, dsttype)' % etype) 'Edge type "%s" is ambiguous. Please use canonical edge type '
"in the form of (srctype, etype, dsttype)" % etype
)
return ret return ret
NODE_PART_POLICY = 'node'
EDGE_PART_POLICY = 'edge' NODE_PART_POLICY = "node"
EDGE_PART_POLICY = "edge"
class PartitionPolicy(object): class PartitionPolicy(object):
"""This defines a partition policy for a distributed tensor or distributed embedding. """This defines a partition policy for a distributed tensor or distributed embedding.
...@@ -1047,11 +1152,14 @@ class PartitionPolicy(object): ...@@ -1047,11 +1152,14 @@ class PartitionPolicy(object):
partition_book : GraphPartitionBook partition_book : GraphPartitionBook
A graph partition book A graph partition book
""" """
def __init__(self, policy_str, partition_book): def __init__(self, policy_str, partition_book):
splits = policy_str.split(':') splits = policy_str.split(":")
if len(splits) == 1: if len(splits) == 1:
assert policy_str in (EDGE_PART_POLICY, NODE_PART_POLICY), \ assert policy_str in (
'policy_str must contain \'edge\' or \'node\'.' EDGE_PART_POLICY,
NODE_PART_POLICY,
), "policy_str must contain 'edge' or 'node'."
if NODE_PART_POLICY == policy_str: if NODE_PART_POLICY == policy_str:
policy_str = NODE_PART_POLICY + ":" + DEFAULT_NTYPE policy_str = NODE_PART_POLICY + ":" + DEFAULT_NTYPE
else: else:
...@@ -1106,8 +1214,7 @@ class PartitionPolicy(object): ...@@ -1106,8 +1214,7 @@ class PartitionPolicy(object):
return self._partition_book return self._partition_book
def get_data_name(self, name): def get_data_name(self, name):
"""Get HeteroDataName """Get HeteroDataName"""
"""
is_node = NODE_PART_POLICY in self.policy_str is_node = NODE_PART_POLICY in self.policy_str
return HeteroDataName(is_node, self.type_name, name) return HeteroDataName(is_node, self.type_name, name)
...@@ -1125,11 +1232,15 @@ class PartitionPolicy(object): ...@@ -1125,11 +1232,15 @@ class PartitionPolicy(object):
local ID tensor local ID tensor
""" """
if EDGE_PART_POLICY in self.policy_str: if EDGE_PART_POLICY in self.policy_str:
return self._partition_book.eid2localeid(id_tensor, self._part_id, self.type_name) return self._partition_book.eid2localeid(
id_tensor, self._part_id, self.type_name
)
elif NODE_PART_POLICY in self.policy_str: elif NODE_PART_POLICY in self.policy_str:
return self._partition_book.nid2localnid(id_tensor, self._part_id, self.type_name) return self._partition_book.nid2localnid(
id_tensor, self._part_id, self.type_name
)
else: else:
raise RuntimeError('Cannot support policy: %s ' % self.policy_str) raise RuntimeError("Cannot support policy: %s " % self.policy_str)
def to_partid(self, id_tensor): def to_partid(self, id_tensor):
"""Mapping global ID to partition ID. """Mapping global ID to partition ID.
...@@ -1149,7 +1260,7 @@ class PartitionPolicy(object): ...@@ -1149,7 +1260,7 @@ class PartitionPolicy(object):
elif NODE_PART_POLICY in self.policy_str: elif NODE_PART_POLICY in self.policy_str:
return self._partition_book.nid2partid(id_tensor, self.type_name) return self._partition_book.nid2partid(id_tensor, self.type_name)
else: else:
raise RuntimeError('Cannot support policy: %s ' % self.policy_str) raise RuntimeError("Cannot support policy: %s " % self.policy_str)
def get_part_size(self): def get_part_size(self):
"""Get data size of current partition. """Get data size of current partition.
...@@ -1160,11 +1271,15 @@ class PartitionPolicy(object): ...@@ -1160,11 +1271,15 @@ class PartitionPolicy(object):
data size data size
""" """
if EDGE_PART_POLICY in self.policy_str: if EDGE_PART_POLICY in self.policy_str:
return len(self._partition_book.partid2eids(self._part_id, self.type_name)) return len(
self._partition_book.partid2eids(self._part_id, self.type_name)
)
elif NODE_PART_POLICY in self.policy_str: elif NODE_PART_POLICY in self.policy_str:
return len(self._partition_book.partid2nids(self._part_id, self.type_name)) return len(
self._partition_book.partid2nids(self._part_id, self.type_name)
)
else: else:
raise RuntimeError('Cannot support policy: %s ' % self.policy_str) raise RuntimeError("Cannot support policy: %s " % self.policy_str)
def get_size(self): def get_size(self):
"""Get the full size of the data. """Get the full size of the data.
...@@ -1179,24 +1294,29 @@ class PartitionPolicy(object): ...@@ -1179,24 +1294,29 @@ class PartitionPolicy(object):
elif NODE_PART_POLICY in self.policy_str: elif NODE_PART_POLICY in self.policy_str:
return self._partition_book._num_nodes(self.type_name) return self._partition_book._num_nodes(self.type_name)
else: else:
raise RuntimeError('Cannot support policy: %s ' % self.policy_str) raise RuntimeError("Cannot support policy: %s " % self.policy_str)
class NodePartitionPolicy(PartitionPolicy): class NodePartitionPolicy(PartitionPolicy):
'''Partition policy for nodes. """Partition policy for nodes."""
'''
def __init__(self, partition_book, ntype=DEFAULT_NTYPE): def __init__(self, partition_book, ntype=DEFAULT_NTYPE):
super(NodePartitionPolicy, self).__init__( super(NodePartitionPolicy, self).__init__(
NODE_PART_POLICY + ':' + ntype, partition_book) NODE_PART_POLICY + ":" + ntype, partition_book
)
class EdgePartitionPolicy(PartitionPolicy): class EdgePartitionPolicy(PartitionPolicy):
'''Partition policy for edges. """Partition policy for edges."""
'''
def __init__(self, partition_book, etype=DEFAULT_ETYPE): def __init__(self, partition_book, etype=DEFAULT_ETYPE):
super(EdgePartitionPolicy, self).__init__( super(EdgePartitionPolicy, self).__init__(
EDGE_PART_POLICY + ':' + str(etype), partition_book) EDGE_PART_POLICY + ":" + str(etype), partition_book
)
class HeteroDataName(object): class HeteroDataName(object):
''' The data name in a heterogeneous graph. """The data name in a heterogeneous graph.
A unique data name has three components: A unique data name has three components:
* indicate it's node data or edge data. * indicate it's node data or edge data.
...@@ -1211,7 +1331,8 @@ class HeteroDataName(object): ...@@ -1211,7 +1331,8 @@ class HeteroDataName(object):
The type of the node/edge. The type of the node/edge.
data_name : str data_name : str
The name of the data. The name of the data.
''' """
def __init__(self, is_node, entity_type, data_name): def __init__(self, is_node, entity_type, data_name):
self._policy = NODE_PART_POLICY if is_node else EDGE_PART_POLICY self._policy = NODE_PART_POLICY if is_node else EDGE_PART_POLICY
self._entity_type = entity_type self._entity_type = entity_type
...@@ -1219,41 +1340,38 @@ class HeteroDataName(object): ...@@ -1219,41 +1340,38 @@ class HeteroDataName(object):
@property @property
def policy_str(self): def policy_str(self):
''' concatenate policy and entity type into string """concatenate policy and entity type into string"""
''' return self._policy + ":" + str(self.get_type())
return self._policy + ':' + str(self.get_type())
def is_node(self): def is_node(self):
''' Is this the name of node data """Is this the name of node data"""
'''
return NODE_PART_POLICY in self.policy_str return NODE_PART_POLICY in self.policy_str
def is_edge(self): def is_edge(self):
''' Is this the name of edge data """Is this the name of edge data"""
'''
return EDGE_PART_POLICY in self.policy_str return EDGE_PART_POLICY in self.policy_str
def get_type(self): def get_type(self):
''' The type of the node/edge. """The type of the node/edge.
This is only meaningful in a heterogeneous graph. This is only meaningful in a heterogeneous graph.
In homogeneous graph, type is '_N' for a node and '_E' for an edge. In homogeneous graph, type is '_N' for a node and '_E' for an edge.
''' """
return self._entity_type return self._entity_type
def get_name(self): def get_name(self):
''' The name of the data. """The name of the data."""
'''
return self.data_name return self.data_name
def __str__(self): def __str__(self):
''' The full name of the data. """The full name of the data.
The full name is used as the key in the KVStore. The full name is used as the key in the KVStore.
''' """
return self.policy_str + ':' + self.data_name return self.policy_str + ":" + self.data_name
def parse_hetero_data_name(name): def parse_hetero_data_name(name):
'''Parse data name and create HeteroDataName. """Parse data name and create HeteroDataName.
The data name has a specialized format. We can parse the name to determine if The data name has a specialized format. We can parse the name to determine if
it's node data or edge data, node/edge type and its actual name. The data name it's node data or edge data, node/edge type and its actual name. The data name
...@@ -1267,9 +1385,15 @@ def parse_hetero_data_name(name): ...@@ -1267,9 +1385,15 @@ def parse_hetero_data_name(name):
Returns Returns
------- -------
HeteroDataName HeteroDataName
''' """
names = name.split(':') names = name.split(":")
assert len(names) == 3, '{} is not a valid heterograph data name'.format(name) assert len(names) == 3, "{} is not a valid heterograph data name".format(
assert names[0] in (NODE_PART_POLICY, EDGE_PART_POLICY), \ name
'{} is not a valid heterograph data name'.format(name) )
return HeteroDataName(names[0] == NODE_PART_POLICY, _str_to_tuple(names[1]), names[2]) assert names[0] in (
NODE_PART_POLICY,
EDGE_PART_POLICY,
), "{} is not a valid heterograph data name".format(name)
return HeteroDataName(
names[0] == NODE_PART_POLICY, _str_to_tuple(names[1]), names[2]
)
"""A set of graph services of getting subgraphs from DistGraph""" """A set of graph services of getting subgraphs from DistGraph"""
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from .rpc import Request, Response, send_requests_to_machine, recv_responses from .. import backend as F
from ..sampling import sample_neighbors as local_sample_neighbors from ..base import EID, NID
from ..convert import graph, heterograph
from ..sampling import sample_etype_neighbors as local_sample_etype_neighbors from ..sampling import sample_etype_neighbors as local_sample_etype_neighbors
from ..sampling import sample_neighbors as local_sample_neighbors
from ..subgraph import in_subgraph as local_in_subgraph from ..subgraph import in_subgraph as local_in_subgraph
from .rpc import register_service
from ..convert import graph, heterograph
from ..base import NID, EID
from ..utils import toindex from ..utils import toindex
from .. import backend as F from .rpc import (
Request,
Response,
recv_responses,
register_service,
send_requests_to_machine,
)
__all__ = [ __all__ = [
'sample_neighbors', 'sample_etype_neighbors', "sample_neighbors",
'in_subgraph', 'find_edges' "sample_etype_neighbors",
"in_subgraph",
"find_edges",
] ]
SAMPLING_SERVICE_ID = 6657 SAMPLING_SERVICE_ID = 6657
...@@ -24,6 +32,7 @@ OUTDEGREE_SERVICE_ID = 6660 ...@@ -24,6 +32,7 @@ OUTDEGREE_SERVICE_ID = 6660
INDEGREE_SERVICE_ID = 6661 INDEGREE_SERVICE_ID = 6661
ETYPE_SAMPLING_SERVICE_ID = 6662 ETYPE_SAMPLING_SERVICE_ID = 6662
class SubgraphResponse(Response): class SubgraphResponse(Response):
"""The response for sampling and in_subgraph""" """The response for sampling and in_subgraph"""
...@@ -38,6 +47,7 @@ class SubgraphResponse(Response): ...@@ -38,6 +47,7 @@ class SubgraphResponse(Response):
def __getstate__(self): def __getstate__(self):
return self.global_src, self.global_dst, self.global_eids return self.global_src, self.global_dst, self.global_eids
class FindEdgeResponse(Response): class FindEdgeResponse(Response):
"""The response for sampling and in_subgraph""" """The response for sampling and in_subgraph"""
...@@ -52,8 +62,11 @@ class FindEdgeResponse(Response): ...@@ -52,8 +62,11 @@ class FindEdgeResponse(Response):
def __getstate__(self): def __getstate__(self):
return self.global_src, self.global_dst, self.order_id return self.global_src, self.global_dst, self.order_id
def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace):
""" Sample from local partition. def _sample_neighbors(
local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
):
"""Sample from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs, The input nodes use global IDs. We need to map the global node IDs to local node IDs,
perform sampling and map the sampled results to the global IDs space again. perform sampling and map the sampled results to the global IDs space again.
...@@ -64,17 +77,35 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr ...@@ -64,17 +77,35 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
local_ids = F.astype(local_ids, local_g.idtype) local_ids = F.astype(local_ids, local_g.idtype)
# local_ids = self.seed_nodes # local_ids = self.seed_nodes
sampled_graph = local_sample_neighbors( sampled_graph = local_sample_neighbors(
local_g, local_ids, fan_out, edge_dir, prob, replace, _dist_training=True) local_g,
local_ids,
fan_out,
edge_dir,
prob,
replace,
_dist_training=True,
)
global_nid_mapping = local_g.ndata[NID] global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
global_src, global_dst = F.gather_row(global_nid_mapping, src), \ global_src, global_dst = F.gather_row(
F.gather_row(global_nid_mapping, dst) global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids return global_src, global_dst, global_eids
def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
fan_out, edge_dir, prob, replace, etype_sorted=False): def _sample_etype_neighbors(
""" Sample from local partition. local_g,
partition_book,
seed_nodes,
etype_field,
fan_out,
edge_dir,
prob,
replace,
etype_sorted=False,
):
"""Sample from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs, The input nodes use global IDs. We need to map the global node IDs to local node IDs,
perform sampling and map the sampled results to the global IDs space again. perform sampling and map the sampled results to the global IDs space again.
...@@ -85,18 +116,28 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field, ...@@ -85,18 +116,28 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
local_ids = F.astype(local_ids, local_g.idtype) local_ids = F.astype(local_ids, local_g.idtype)
sampled_graph = local_sample_etype_neighbors( sampled_graph = local_sample_etype_neighbors(
local_g, local_ids, etype_field, fan_out, edge_dir, prob, replace, local_g,
etype_sorted=etype_sorted, _dist_training=True) local_ids,
etype_field,
fan_out,
edge_dir,
prob,
replace,
etype_sorted=etype_sorted,
_dist_training=True,
)
global_nid_mapping = local_g.ndata[NID] global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
global_src, global_dst = F.gather_row(global_nid_mapping, src), \ global_src, global_dst = F.gather_row(
F.gather_row(global_nid_mapping, dst) global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids return global_src, global_dst, global_eids
def _find_edges(local_g, partition_book, seed_edges): def _find_edges(local_g, partition_book, seed_edges):
"""Given an edge ID array, return the source """Given an edge ID array, return the source
and destination node ID array ``s`` and ``d`` in the local partition. and destination node ID array ``s`` and ``d`` in the local partition.
""" """
local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid) local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)
local_eids = F.astype(local_eids, local_g.idtype) local_eids = F.astype(local_eids, local_g.idtype)
...@@ -106,22 +147,23 @@ def _find_edges(local_g, partition_book, seed_edges): ...@@ -106,22 +147,23 @@ def _find_edges(local_g, partition_book, seed_edges):
global_dst = global_nid_mapping[local_dst] global_dst = global_nid_mapping[local_dst]
return global_src, global_dst return global_src, global_dst
def _in_degrees(local_g, partition_book, n): def _in_degrees(local_g, partition_book, n):
"""Get in-degree of the nodes in the local partition. """Get in-degree of the nodes in the local partition."""
"""
local_nids = partition_book.nid2localnid(n, partition_book.partid) local_nids = partition_book.nid2localnid(n, partition_book.partid)
local_nids = F.astype(local_nids, local_g.idtype) local_nids = F.astype(local_nids, local_g.idtype)
return local_g.in_degrees(local_nids) return local_g.in_degrees(local_nids)
def _out_degrees(local_g, partition_book, n): def _out_degrees(local_g, partition_book, n):
"""Get out-degree of the nodes in the local partition. """Get out-degree of the nodes in the local partition."""
"""
local_nids = partition_book.nid2localnid(n, partition_book.partid) local_nids = partition_book.nid2localnid(n, partition_book.partid)
local_nids = F.astype(local_nids, local_g.idtype) local_nids = F.astype(local_nids, local_g.idtype)
return local_g.out_degrees(local_nids) return local_g.out_degrees(local_nids)
def _in_subgraph(local_g, partition_book, seed_nodes): def _in_subgraph(local_g, partition_book, seed_nodes):
""" Get in subgraph from local partition. """Get in subgraph from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs, The input nodes use global IDs. We need to map the global node IDs to local node IDs,
get in-subgraph and map the sampled results to the global IDs space again. get in-subgraph and map the sampled results to the global IDs space again.
...@@ -142,7 +184,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes): ...@@ -142,7 +184,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
class SamplingRequest(Request): class SamplingRequest(Request):
"""Sampling Request""" """Sampling Request"""
def __init__(self, nodes, fan_out, edge_dir='in', prob=None, replace=False): def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False):
self.seed_nodes = nodes self.seed_nodes = nodes
self.edge_dir = edge_dir self.edge_dir = edge_dir
self.prob = prob self.prob = prob
...@@ -150,25 +192,51 @@ class SamplingRequest(Request): ...@@ -150,25 +192,51 @@ class SamplingRequest(Request):
self.fan_out = fan_out self.fan_out = fan_out
def __setstate__(self, state): def __setstate__(self, state):
self.seed_nodes, self.edge_dir, self.prob, self.replace, self.fan_out = state (
self.seed_nodes,
self.edge_dir,
self.prob,
self.replace,
self.fan_out,
) = state
def __getstate__(self): def __getstate__(self):
return self.seed_nodes, self.edge_dir, self.prob, self.replace, self.fan_out return (
self.seed_nodes,
self.edge_dir,
self.prob,
self.replace,
self.fan_out,
)
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
global_src, global_dst, global_eids = _sample_neighbors(local_g, partition_book, global_src, global_dst, global_eids = _sample_neighbors(
self.seed_nodes, local_g,
self.fan_out, self.edge_dir, partition_book,
self.prob, self.replace) self.seed_nodes,
self.fan_out,
self.edge_dir,
self.prob,
self.replace,
)
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids)
class SamplingRequestEtype(Request): class SamplingRequestEtype(Request):
"""Sampling Request""" """Sampling Request"""
def __init__(self, nodes, etype_field, fan_out, edge_dir='in', def __init__(
prob=None, replace=False, etype_sorted=True): self,
nodes,
etype_field,
fan_out,
edge_dir="in",
prob=None,
replace=False,
etype_sorted=True,
):
self.seed_nodes = nodes self.seed_nodes = nodes
self.edge_dir = edge_dir self.edge_dir = edge_dir
self.prob = prob self.prob = prob
...@@ -178,27 +246,44 @@ class SamplingRequestEtype(Request): ...@@ -178,27 +246,44 @@ class SamplingRequestEtype(Request):
self.etype_sorted = etype_sorted self.etype_sorted = etype_sorted
def __setstate__(self, state): def __setstate__(self, state):
self.seed_nodes, self.edge_dir, self.prob, self.replace, \ (
self.fan_out, self.etype_field, self.etype_sorted = state self.seed_nodes,
self.edge_dir,
self.prob,
self.replace,
self.fan_out,
self.etype_field,
self.etype_sorted,
) = state
def __getstate__(self): def __getstate__(self):
return self.seed_nodes, self.edge_dir, self.prob, self.replace, \ return (
self.fan_out, self.etype_field, self.etype_sorted self.seed_nodes,
self.edge_dir,
self.prob,
self.replace,
self.fan_out,
self.etype_field,
self.etype_sorted,
)
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
global_src, global_dst, global_eids = _sample_etype_neighbors(local_g, global_src, global_dst, global_eids = _sample_etype_neighbors(
partition_book, local_g,
self.seed_nodes, partition_book,
self.etype_field, self.seed_nodes,
self.fan_out, self.etype_field,
self.edge_dir, self.fan_out,
self.prob, self.edge_dir,
self.replace, self.prob,
self.etype_sorted) self.replace,
self.etype_sorted,
)
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids)
class EdgesRequest(Request): class EdgesRequest(Request):
"""Edges Request""" """Edges Request"""
...@@ -215,10 +300,13 @@ class EdgesRequest(Request): ...@@ -215,10 +300,13 @@ class EdgesRequest(Request):
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
global_src, global_dst = _find_edges(local_g, partition_book, self.edge_ids) global_src, global_dst = _find_edges(
local_g, partition_book, self.edge_ids
)
return FindEdgeResponse(global_src, global_dst, self.order_id) return FindEdgeResponse(global_src, global_dst, self.order_id)
class InDegreeRequest(Request): class InDegreeRequest(Request):
"""In-degree Request""" """In-degree Request"""
...@@ -239,6 +327,7 @@ class InDegreeRequest(Request): ...@@ -239,6 +327,7 @@ class InDegreeRequest(Request):
return InDegreeResponse(deg, self.order_id) return InDegreeResponse(deg, self.order_id)
class InDegreeResponse(Response): class InDegreeResponse(Response):
"""The response for in-degree""" """The response for in-degree"""
...@@ -252,6 +341,7 @@ class InDegreeResponse(Response): ...@@ -252,6 +341,7 @@ class InDegreeResponse(Response):
def __getstate__(self): def __getstate__(self):
return self.val, self.order_id return self.val, self.order_id
class OutDegreeRequest(Request): class OutDegreeRequest(Request):
"""Out-degree Request""" """Out-degree Request"""
...@@ -272,6 +362,7 @@ class OutDegreeRequest(Request): ...@@ -272,6 +362,7 @@ class OutDegreeRequest(Request):
return OutDegreeResponse(deg, self.order_id) return OutDegreeResponse(deg, self.order_id)
class OutDegreeResponse(Response): class OutDegreeResponse(Response):
"""The response for out-degree""" """The response for out-degree"""
...@@ -285,6 +376,7 @@ class OutDegreeResponse(Response): ...@@ -285,6 +376,7 @@ class OutDegreeResponse(Response):
def __getstate__(self): def __getstate__(self):
return self.val, self.order_id return self.val, self.order_id
class InSubgraphRequest(Request): class InSubgraphRequest(Request):
"""InSubgraph Request""" """InSubgraph Request"""
...@@ -300,8 +392,9 @@ class InSubgraphRequest(Request): ...@@ -300,8 +392,9 @@ class InSubgraphRequest(Request):
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
global_src, global_dst, global_eids = _in_subgraph(local_g, partition_book, global_src, global_dst, global_eids = _in_subgraph(
self.seed_nodes) local_g, partition_book, self.seed_nodes
)
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids)
...@@ -326,10 +419,14 @@ def merge_graphs(res_list, num_nodes): ...@@ -326,10 +419,14 @@ def merge_graphs(res_list, num_nodes):
g.edata[EID] = eid_tensor g.edata[EID] = eid_tensor
return g return g
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
LocalSampledGraph = namedtuple(
"LocalSampledGraph", "global_src global_dst global_eids"
)
def _distributed_access(g, nodes, issue_remote_req, local_access): def _distributed_access(g, nodes, issue_remote_req, local_access):
'''A routine that fetches local neighborhood of nodes from the distributed graph. """A routine that fetches local neighborhood of nodes from the distributed graph.
The local neighborhood of some nodes are stored in the local machine and the other The local neighborhood of some nodes are stored in the local machine and the other
nodes have their neighborhood on remote machines. This code will issue remote nodes have their neighborhood on remote machines. This code will issue remote
...@@ -352,7 +449,7 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): ...@@ -352,7 +449,7 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
------- -------
DGLHeteroGraph DGLHeteroGraph
The subgraph that contains the neighborhoods of all input nodes. The subgraph that contains the neighborhoods of all input nodes.
''' """
req_list = [] req_list = []
partition_book = g.get_partition_book() partition_book = g.get_partition_book()
nodes = toindex(nodes).tousertensor() nodes = toindex(nodes).tousertensor()
...@@ -379,7 +476,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): ...@@ -379,7 +476,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition. # sample neighbors for the nodes in the local partition.
res_list = [] res_list = []
if local_nids is not None: if local_nids is not None:
src, dst, eids = local_access(g.local_partition, partition_book, local_nids) src, dst, eids = local_access(
g.local_partition, partition_book, local_nids
)
res_list.append(LocalSampledGraph(src, dst, eids)) res_list.append(LocalSampledGraph(src, dst, eids))
# receive responses from remote machines. # receive responses from remote machines.
...@@ -390,13 +489,18 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): ...@@ -390,13 +489,18 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
sampled_graph = merge_graphs(res_list, g.number_of_nodes()) sampled_graph = merge_graphs(res_list, g.number_of_nodes())
return sampled_graph return sampled_graph
def _frontier_to_heterogeneous_graph(g, frontier, gpb): def _frontier_to_heterogeneous_graph(g, frontier, gpb):
# We need to handle empty frontiers correctly. # We need to handle empty frontiers correctly.
if frontier.number_of_edges() == 0: if frontier.number_of_edges() == 0:
data_dict = {etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes} data_dict = {
return heterograph(data_dict, etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}, }
idtype=g.idtype) return heterograph(
data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype,
)
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID]) etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges() src, dst = frontier.edges()
...@@ -413,19 +517,32 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): ...@@ -413,19 +517,32 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
canonical_etype = g.canonical_etypes[etid] canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0: if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \ data_dict[canonical_etype] = (
F.boolean_mask(dst, type_idx)) F.boolean_mask(src, type_idx),
F.boolean_mask(dst, type_idx),
)
edge_ids[etype] = F.boolean_mask(eid, type_idx) edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict, hg = heterograph(
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}, data_dict,
idtype=g.idtype) {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype,
)
for etype in edge_ids: for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype] hg.edges[etype].data[EID] = edge_ids[etype]
return hg return hg
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
prob=None, replace=False, etype_sorted=True): def sample_etype_neighbors(
g,
nodes,
etype_field,
fanout,
edge_dir="in",
prob=None,
replace=False,
etype_sorted=True,
):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -495,28 +612,50 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', ...@@ -495,28 +612,50 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
if isinstance(nodes, dict): if isinstance(nodes, dict):
homo_nids = [] homo_nids = []
for ntype in nodes.keys(): for ntype in nodes.keys():
assert ntype in g.ntypes, \ assert (
'The sampled node type {} does not exist in the input graph'.format(ntype) ntype in g.ntypes
), "The sampled node type {} does not exist in the input graph".format(
ntype
)
if F.is_tensor(nodes[ntype]): if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype] typed_nodes = nodes[ntype]
else: else:
typed_nodes = toindex(nodes[ntype]).tousertensor() typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype)) homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
nodes = F.cat(homo_nids, 0) nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir, return SamplingRequestEtype(
prob=prob, replace=replace, etype_sorted=etype_sorted) node_ids,
etype_field,
fanout,
edge_dir=edge_dir,
prob=prob,
replace=replace,
etype_sorted=etype_sorted,
)
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
return _sample_etype_neighbors(local_g, partition_book, local_nids, return _sample_etype_neighbors(
etype_field, fanout, edge_dir, prob, replace, local_g,
etype_sorted=etype_sorted) partition_book,
local_nids,
etype_field,
fanout,
edge_dir,
prob,
replace,
etype_sorted=etype_sorted,
)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous: if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else: else:
return frontier return frontier
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -570,7 +709,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -570,7 +709,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
assert isinstance(nodes, dict) assert isinstance(nodes, dict)
homo_nids = [] homo_nids = []
for ntype in nodes: for ntype in nodes:
assert ntype in g.ntypes, 'The sampled node type does not exist in the input graph' assert (
ntype in g.ntypes
), "The sampled node type does not exist in the input graph"
if F.is_tensor(nodes[ntype]): if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype] typed_nodes = nodes[ntype]
else: else:
...@@ -582,17 +723,22 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -582,17 +723,22 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
nodes = list(nodes.values())[0] nodes = list(nodes.values())[0]
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return SamplingRequest(node_ids, fanout, edge_dir=edge_dir, return SamplingRequest(
prob=prob, replace=replace) node_ids, fanout, edge_dir=edge_dir, prob=prob, replace=replace
)
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
return _sample_neighbors(local_g, partition_book, local_nids, return _sample_neighbors(
fanout, edge_dir, prob, replace) local_g, partition_book, local_nids, fanout, edge_dir, prob, replace
)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous: if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else: else:
return frontier return frontier
def _distributed_edge_access(g, edges, issue_remote_req, local_access): def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""A routine that fetches local edges from distributed graph. """A routine that fetches local edges from distributed graph.
...@@ -626,7 +772,7 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access): ...@@ -626,7 +772,7 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
local_eids = None local_eids = None
reorder_idx = [] reorder_idx = []
for pid in range(partition_book.num_partitions()): for pid in range(partition_book.num_partitions()):
mask = (partition_id == pid) mask = partition_id == pid
edge_id = F.boolean_mask(edges, mask) edge_id = F.boolean_mask(edges, mask)
reorder_idx.append(F.nonzero_1d(mask)) reorder_idx.append(F.nonzero_1d(mask))
if pid == partition_book.partid and g.local_partition is not None: if pid == partition_book.partid and g.local_partition is not None:
...@@ -646,8 +792,12 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access): ...@@ -646,8 +792,12 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
dst_ids = F.zeros_like(edges) dst_ids = F.zeros_like(edges)
if local_eids is not None: if local_eids is not None:
src, dst = local_access(g.local_partition, partition_book, local_eids) src, dst = local_access(g.local_partition, partition_book, local_eids)
src_ids = F.scatter_row(src_ids, reorder_idx[partition_book.partid], src) src_ids = F.scatter_row(
dst_ids = F.scatter_row(dst_ids, reorder_idx[partition_book.partid], dst) src_ids, reorder_idx[partition_book.partid], src
)
dst_ids = F.scatter_row(
dst_ids, reorder_idx[partition_book.partid], dst
)
# receive responses from remote machines. # receive responses from remote machines.
if msgseq2pos is not None: if msgseq2pos is not None:
...@@ -659,8 +809,9 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access): ...@@ -659,8 +809,9 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst) dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst)
return src_ids, dst_ids return src_ids, dst_ids
def find_edges(g, edge_ids):
    """Given an edge ID array, return the source and destination
    node ID array ``s`` and ``d`` from a distributed graph.

    ``s[i]`` and ``d[i]`` are source and destination node ID for
    edge ``eid[i]``.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    edge_ids : tensor
        The edge IDs to look up.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """

    # Builds the RPC request for edges owned by a remote partition.
    def make_remote_request(eids, order_id):
        return EdgesRequest(eids, order_id)

    # Resolves edges that live in the local partition.
    def resolve_locally(local_g, partition_book, eids):
        return _find_edges(local_g, partition_book, eids)

    return _distributed_edge_access(
        g, edge_ids, make_remote_request, resolve_locally
    )
def in_subgraph(g, nodes):
    """Return the subgraph induced on the inbound edges of the given nodes.

    The original edge IDs are stored under ``dgl.EID`` in the subgraph's
    edge features so edge data can be recovered from the parent graph.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node IDs; a dict must contain exactly one node type.
    """
    if isinstance(nodes, dict):
        assert (
            len(nodes) == 1
        ), "The distributed in_subgraph only supports one node type for now."
        nodes = next(iter(nodes.values()))

    # Request builder for nodes owned by remote partitions.
    def make_remote_request(node_ids):
        return InSubgraphRequest(node_ids)

    # Handler for nodes stored in the local partition.
    def resolve_locally(local_g, partition_book, local_nids):
        return _in_subgraph(local_g, partition_book, local_nids)

    return _distributed_access(g, nodes, make_remote_request, resolve_locally)
def _distributed_get_node_property(g, n, issue_remote_req, local_access): def _distributed_get_node_property(g, n, issue_remote_req, local_access):
req_list = [] req_list = []
partition_book = g.get_partition_book() partition_book = g.get_partition_book()
...@@ -729,7 +890,7 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access): ...@@ -729,7 +890,7 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
local_nids = None local_nids = None
reorder_idx = [] reorder_idx = []
for pid in range(partition_book.num_partitions()): for pid in range(partition_book.num_partitions()):
mask = (partition_id == pid) mask = partition_id == pid
nid = F.boolean_mask(n, mask) nid = F.boolean_mask(n, mask)
reorder_idx.append(F.nonzero_1d(mask)) reorder_idx.append(F.nonzero_1d(mask))
if pid == partition_book.partid and g.local_partition is not None: if pid == partition_book.partid and g.local_partition is not None:
...@@ -751,7 +912,9 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access): ...@@ -751,7 +912,9 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
shape = list(F.shape(local_vals)) shape = list(F.shape(local_vals))
shape[0] = len(n) shape[0] = len(n)
vals = F.zeros(shape, F.dtype(local_vals), F.cpu()) vals = F.zeros(shape, F.dtype(local_vals), F.cpu())
vals = F.scatter_row(vals, reorder_idx[partition_book.partid], local_vals) vals = F.scatter_row(
vals, reorder_idx[partition_book.partid], local_vals
)
# receive responses from remote machines. # receive responses from remote machines.
if msgseq2pos is not None: if msgseq2pos is not None:
...@@ -765,27 +928,36 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access): ...@@ -765,27 +928,36 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
vals = F.scatter_row(vals, reorder_idx[result.order_id], val) vals = F.scatter_row(vals, reorder_idx[result.order_id], val)
return vals return vals
def in_degrees(g, v):
    """Get in-degrees of the given nodes from a distributed graph."""

    # Remote partitions answer via an InDegreeRequest RPC.
    def make_remote_request(nodes, order_id):
        return InDegreeRequest(nodes, order_id)

    # Nodes in the local partition are resolved without RPC.
    def resolve_locally(local_g, partition_book, nodes):
        return _in_degrees(local_g, partition_book, nodes)

    return _distributed_get_node_property(
        g, v, make_remote_request, resolve_locally
    )
def out_degrees(g, u):
    """Get out-degrees of the given nodes from a distributed graph."""

    # Remote partitions answer via an OutDegreeRequest RPC.
    def make_remote_request(nodes, order_id):
        return OutDegreeRequest(nodes, order_id)

    # Nodes in the local partition are resolved without RPC.
    def resolve_locally(local_g, partition_book, nodes):
        return _out_degrees(local_g, partition_book, nodes)

    return _distributed_get_node_property(
        g, u, make_remote_request, resolve_locally
    )
# Register the request/response handler pair for every distributed graph
# service exposed by this module.
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(EDGES_SERVICE_ID, EdgesRequest, FindEdgeResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
register_service(OUTDEGREE_SERVICE_ID, OutDegreeRequest, OutDegreeResponse)
register_service(INDEGREE_SERVICE_ID, InDegreeRequest, InDegreeResponse)
register_service(
    ETYPE_SAMPLING_SERVICE_ID, SamplingRequestEtype, SubgraphResponse
)
"""dgl distributed.optims.""" """dgl distributed.optims."""
import importlib import importlib
import sys
import os import os
import sys
from ...backend import backend_name from ...backend import backend_name
from ...utils import expand_as_pair from ...utils import expand_as_pair
def _load_backend(mod_name):
    """Import the backend-specific optimizer submodule and re-export its
    attributes from this package.

    Parameters
    ----------
    mod_name : str
        Name of the backend submodule to load (relative to this package).
    """
    mod = importlib.import_module(f".{mod_name}", __name__)
    thismod = sys.modules[__name__]
    for api, obj in mod.__dict__.items():
        # Skip dunder attributes (__name__, __file__, __doc__, ...) so the
        # backend module does not clobber this package's own metadata.
        if api.startswith("__") and api.endswith("__"):
            continue
        setattr(thismod, api, obj)


_load_backend(backend_name)
"""dgl distributed.optims.""" """dgl distributed.optims."""
import importlib import importlib
import sys
import os import os
import sys
from ...backend import backend_name from ...backend import backend_name
from ...utils import expand_as_pair from ...utils import expand_as_pair
def _load_backend(mod_name):
    """Load the optimizer implementation for the active backend and copy
    its attributes into this module's namespace."""
    backend_mod = importlib.import_module(".%s" % mod_name, __name__)
    this_mod = sys.modules[__name__]
    for name, value in vars(backend_mod).items():
        setattr(this_mod, name, value)


_load_backend(backend_name)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import torch as th import torch as th
import torch.distributed as dist import torch.distributed as dist
def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
    """Each process scatters its list of input tensors to all processes in
    a cluster and gathers the received tensors into the output list.
    The tensors should have the same shape.

    Parameters
    ----------
    rank : int
        Rank of the current process.
    world_size : int
        Number of processes in the group.
    output_tensor_list : list of tensor
        Gathered tensors are written here, indexed by source rank.
    input_tensor_list : list of tensor
        The tensors to exchange.
    """
    cpu = th.device("cpu")
    input_tensor_list = [tensor.to(cpu) for tensor in input_tensor_list]
    # torch.distributed.scatter expects the scatter list only on the source
    # rank; every other rank passes an empty list.
    for src_rank in range(world_size):
        scatter_list = input_tensor_list if src_rank == rank else []
        dist.scatter(output_tensor_list[src_rank], scatter_list, src=src_rank)
def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list): def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
"""Each process scatters list of input tensors to all processes in a cluster """Each process scatters list of input tensors to all processes in a cluster
...@@ -42,9 +48,11 @@ def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list): ...@@ -42,9 +48,11 @@ def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
senders = [] senders = []
for i in range(world_size): for i in range(world_size):
if i == rank: if i == rank:
output_tensor_list[i] = input_tensor_list[i].to(th.device('cpu')) output_tensor_list[i] = input_tensor_list[i].to(th.device("cpu"))
else: else:
sender = dist.isend(input_tensor_list[i].to(th.device('cpu')), dst=i) sender = dist.isend(
input_tensor_list[i].to(th.device("cpu")), dst=i
)
senders.append(sender) senders.append(sender)
for i in range(world_size): for i in range(world_size):
......
...@@ -5,6 +5,7 @@ some work as trainers. ...@@ -5,6 +5,7 @@ some work as trainers.
""" """
import os import os
import numpy as np import numpy as np
from . import rpc from . import rpc
...@@ -12,10 +13,12 @@ from . import rpc ...@@ -12,10 +13,12 @@ from . import rpc
REGISTER_ROLE = 700001 REGISTER_ROLE = 700001
REG_ROLE_MSG = "Register_Role" REG_ROLE_MSG = "Register_Role"
class RegisterRoleResponse(rpc.Response): class RegisterRoleResponse(rpc.Response):
"""Send a confirmation signal (just a short string message) """Send a confirmation signal (just a short string message)
of RegisterRoleRequest to client. of RegisterRoleRequest to client.
""" """
def __init__(self, msg): def __init__(self, msg):
self.msg = msg self.msg = msg
...@@ -25,6 +28,7 @@ class RegisterRoleResponse(rpc.Response): ...@@ -25,6 +28,7 @@ class RegisterRoleResponse(rpc.Response):
def __setstate__(self, state): def __setstate__(self, state):
self.msg = state self.msg = state
class RegisterRoleRequest(rpc.Request): class RegisterRoleRequest(rpc.Request):
"""Send client id and role to server """Send client id and role to server
...@@ -35,6 +39,7 @@ class RegisterRoleRequest(rpc.Request): ...@@ -35,6 +39,7 @@ class RegisterRoleRequest(rpc.Request):
role : str role : str
role of client role of client
""" """
def __init__(self, client_id, machine_id, role): def __init__(self, client_id, machine_id, role):
self.client_id = client_id self.client_id = client_id
self.machine_id = machine_id self.machine_id = machine_id
...@@ -53,7 +58,9 @@ class RegisterRoleRequest(rpc.Request): ...@@ -53,7 +58,9 @@ class RegisterRoleRequest(rpc.Request):
if self.role not in role: if self.role not in role:
role[self.role] = set() role[self.role] = set()
if kv_store is not None: if kv_store is not None:
barrier_count = kv_store.barrier_count.setdefault(self.group_id, {}) barrier_count = kv_store.barrier_count.setdefault(
self.group_id, {}
)
barrier_count[self.role] = 0 barrier_count[self.role] = 0
role[self.role].add((self.client_id, self.machine_id)) role[self.role].add((self.client_id, self.machine_id))
total_count = 0 total_count = 0
...@@ -67,11 +74,14 @@ class RegisterRoleRequest(rpc.Request): ...@@ -67,11 +74,14 @@ class RegisterRoleRequest(rpc.Request):
return res_list return res_list
return None return None
GET_ROLE = 700002 GET_ROLE = 700002
GET_ROLE_MSG = "Get_Role" GET_ROLE_MSG = "Get_Role"
class GetRoleResponse(rpc.Response): class GetRoleResponse(rpc.Response):
"""Send the roles of all client processes""" """Send the roles of all client processes"""
def __init__(self, role): def __init__(self, role):
self.role = role self.role = role
self.msg = GET_ROLE_MSG self.msg = GET_ROLE_MSG
...@@ -82,8 +92,10 @@ class GetRoleResponse(rpc.Response): ...@@ -82,8 +92,10 @@ class GetRoleResponse(rpc.Response):
def __setstate__(self, state): def __setstate__(self, state):
self.role, self.msg = state self.role, self.msg = state
class GetRoleRequest(rpc.Request): class GetRoleRequest(rpc.Request):
"""Send a request to get the roles of all client processes.""" """Send a request to get the roles of all client processes."""
def __init__(self): def __init__(self):
self.msg = GET_ROLE_MSG self.msg = GET_ROLE_MSG
self.group_id = rpc.get_group_id() self.group_id = rpc.get_group_id()
...@@ -97,6 +109,7 @@ class GetRoleRequest(rpc.Request): ...@@ -97,6 +109,7 @@ class GetRoleRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
return GetRoleResponse(server_state.roles[self.group_id]) return GetRoleResponse(server_state.roles[self.group_id])
# The key is role, the value is a dict of mapping RPC rank to a rank within the role. # The key is role, the value is a dict of mapping RPC rank to a rank within the role.
PER_ROLE_RANK = {} PER_ROLE_RANK = {}
...@@ -109,6 +122,7 @@ CUR_ROLE = None ...@@ -109,6 +122,7 @@ CUR_ROLE = None
IS_STANDALONE = False IS_STANDALONE = False
def init_role(role): def init_role(role):
"""Initialize the role of the current process. """Initialize the role of the current process.
...@@ -128,10 +142,10 @@ def init_role(role): ...@@ -128,10 +142,10 @@ def init_role(role):
global GLOBAL_RANK global GLOBAL_RANK
global IS_STANDALONE global IS_STANDALONE
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone': if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
if role == 'default': if role == "default":
GLOBAL_RANK[0] = 0 GLOBAL_RANK[0] = 0
PER_ROLE_RANK['default'] = {0:0} PER_ROLE_RANK["default"] = {0: 0}
IS_STANDALONE = True IS_STANDALONE = True
return return
...@@ -160,7 +174,7 @@ def init_role(role): ...@@ -160,7 +174,7 @@ def init_role(role):
global_rank = 0 global_rank = 0
# We want to ensure that the global rank of the trainer process starts from 0. # We want to ensure that the global rank of the trainer process starts from 0.
role_names = ['default'] role_names = ["default"]
for role_name in response.role: for role_name in response.role:
if role_name not in role_names: if role_name not in role_names:
role_names.append(role_name) role_names.append(role_name)
...@@ -185,6 +199,7 @@ def init_role(role): ...@@ -185,6 +199,7 @@ def init_role(role):
PER_ROLE_RANK[role_name][client_id] = per_role_rank PER_ROLE_RANK[role_name][client_id] = per_role_rank
per_role_rank += 1 per_role_rank += 1
def get_global_rank(): def get_global_rank():
"""Get the global rank """Get the global rank
...@@ -196,6 +211,7 @@ def get_global_rank(): ...@@ -196,6 +211,7 @@ def get_global_rank():
else: else:
return GLOBAL_RANK[rpc.get_rank()] return GLOBAL_RANK[rpc.get_rank()]
def get_rank(role): def get_rank(role):
"""Get the role-specific rank""" """Get the role-specific rank"""
if IS_STANDALONE: if IS_STANDALONE:
...@@ -203,25 +219,29 @@ def get_rank(role): ...@@ -203,25 +219,29 @@ def get_rank(role):
else: else:
return PER_ROLE_RANK[role][rpc.get_rank()] return PER_ROLE_RANK[role][rpc.get_rank()]
def get_trainer_rank():
    """Get the rank of the current trainer process.

    This function can only be called in the trainer process. It will result
    in an error if it's called in the process of other roles.
    """
    assert CUR_ROLE == "default"
    if IS_STANDALONE:
        # Standalone mode runs a single trainer.
        return 0
    return PER_ROLE_RANK["default"][rpc.get_rank()]
def get_role():
    """Return the role assigned to the current process."""
    return CUR_ROLE
def get_num_trainers():
    """Return the number of registered trainer processes."""
    return len(PER_ROLE_RANK["default"])
# Register the role-management RPC services with the RPC layer.
rpc.register_service(REGISTER_ROLE, RegisterRoleRequest, RegisterRoleResponse)
rpc.register_service(GET_ROLE, GetRoleRequest, GetRoleResponse)
"""RPC components. They are typically functions or utilities used by both """RPC components. They are typically functions or utilities used by both
server and clients.""" server and clients."""
import os
import abc import abc
import os
import pickle import pickle
import random import random
import numpy as np
from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE import numpy as np
from .._ffi.object import register_object, ObjectBase from .. import backend as F
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..base import DGLError from ..base import DGLError
from .. import backend as F from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \ __all__ = [
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \ "set_rank",
'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \ "get_rank",
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \ "Request",
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \ "Response",
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', 'DistConnectError', \ "register_service",
'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory'] "create_sender",
"create_receiver",
"finalize_sender",
"finalize_receiver",
"wait_for_senders",
"connect_receiver",
"read_ip_config",
"get_group_id",
"get_num_machines",
"set_num_machines",
"get_machine_id",
"set_machine_id",
"send_request",
"recv_request",
"send_response",
"recv_response",
"remote_call",
"send_request_to_machine",
"remote_call_to_machine",
"fast_pull",
"DistConnectError",
"get_num_client",
"set_num_client",
"client_barrier",
"copy_data_to_shared_memory",
]
REQUEST_CLASS_TO_SERVICE_ID = {} REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {} RESPONSE_CLASS_TO_SERVICE_ID = {}
...@@ -27,6 +52,7 @@ SERVICE_ID_TO_PROPERTY = {} ...@@ -27,6 +52,7 @@ SERVICE_ID_TO_PROPERTY = {}
DEFUALT_PORT = 30050 DEFUALT_PORT = 30050
def read_ip_config(filename, num_servers): def read_ip_config(filename, num_servers):
"""Read network configuration information of server from file. """Read network configuration information of server from file.
...@@ -75,13 +101,15 @@ def read_ip_config(filename, num_servers): ...@@ -75,13 +101,15 @@ def read_ip_config(filename, num_servers):
6:[3, '172.31.30.180', 30050, 2], 6:[3, '172.31.30.180', 30050, 2],
7:[3, '172.31.30.180', 30051, 2]} 7:[3, '172.31.30.180', 30051, 2]}
""" """
assert len(filename) > 0, 'filename cannot be empty.' assert len(filename) > 0, "filename cannot be empty."
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers assert num_servers > 0, (
"num_servers (%d) must be a positive number." % num_servers
)
server_namebook = {} server_namebook = {}
try: try:
server_id = 0 server_id = 0
machine_id = 0 machine_id = 0
lines = [line.rstrip('\n') for line in open(filename)] lines = [line.rstrip("\n") for line in open(filename)]
for line in lines: for line in lines:
result = line.split() result = line.split()
if len(result) == 2: if len(result) == 2:
...@@ -89,21 +117,27 @@ def read_ip_config(filename, num_servers): ...@@ -89,21 +117,27 @@ def read_ip_config(filename, num_servers):
elif len(result) == 1: elif len(result) == 1:
port = DEFUALT_PORT port = DEFUALT_PORT
else: else:
raise RuntimeError('length of result can only be 1 or 2.') raise RuntimeError("length of result can only be 1 or 2.")
ip_addr = result[0] ip_addr = result[0]
for s_count in range(num_servers): for s_count in range(num_servers):
server_namebook[server_id] = [machine_id, ip_addr, port+s_count, num_servers] server_namebook[server_id] = [
machine_id,
ip_addr,
port + s_count,
num_servers,
]
server_id += 1 server_id += 1
machine_id += 1 machine_id += 1
except RuntimeError: except RuntimeError:
print("Error: data format on each line should be: [ip] [port]") print("Error: data format on each line should be: [ip] [port]")
return server_namebook return server_namebook
def reset():
    """Reset the RPC context of this process."""
    _CAPI_DGLRPCReset()
def create_sender(max_queue_size, net_type):
    """Create the RPC sender of this process.

    Parameters
    ----------
    max_queue_size : int
        Maximum size of the message queue in bytes.
    net_type : str
        Networking type. Current options are: 'socket', 'tensorpipe'.
    """
    # 0 lets the C layer pick its own default thread count.
    thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
    _CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, thread_count)
def create_receiver(max_queue_size, net_type):
    """Create the RPC receiver of this process.

    Parameters
    ----------
    max_queue_size : int
        Maximum size of the message queue in bytes.
    net_type : str
        Networking type. Current options are: 'socket', 'tensorpipe'.
    """
    # 0 lets the C layer pick its own default thread count.
    thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
    _CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, thread_count)
def finalize_sender():
    """Release the RPC sender owned by this process."""
    _CAPI_DGLRPCFinalizeSender()
def finalize_receiver():
    """Release the RPC receiver owned by this process."""
    _CAPI_DGLRPCFinalizeReceiver()
def wait_for_senders(ip_addr, port, num_senders, blocking=True):
    """Wait for all of the senders' connections.

    Parameters
    ----------
    ip_addr : str
        IP address to listen on.
    port : int
        Port to listen on.
    num_senders : int
        Number of senders expected to connect.
    blocking : bool
        Whether to block until every sender has connected.
    """
    _CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders), blocking)
def connect_receiver(ip_addr, port, recv_id, group_id=-1):
    """Connect to target receiver.

    Parameters
    ----------
    ip_addr : str
        Receiver's IP address.
    port : int
        Receiver's port.
    recv_id : int
        Receiver's ID.
    group_id : int
        Group of the receiver; -1 means ``recv_id`` is used directly
        without client registration.
    """
    if group_id == -1:
        target_id = recv_id
    else:
        target_id = register_client(recv_id, group_id)
    if target_id < 0:
        raise DGLError("Invalid target id: {}".format(target_id))
    return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
def connect_receiver_finalize(max_try_times):
    """Finalize the action to connect to receivers. Make sure that either
    all connections are successfully established or connection fails.

    Parameters
    ----------
    max_try_times : int
        Maximum number of connection attempts before giving up.
    """
    return _CAPI_DGLRPCConnectReceiverFinalize(max_try_times)
def set_rank(rank):
    """Set the rank of this process.

    Parameters
    ----------
    rank : int
        Rank value to assign to this process.
    """
    _CAPI_DGLRPCSetRank(int(rank))
def get_rank():
    """Get the rank of this process.

    Returns
    -------
    int
        Rank of this process.
    """
    return _CAPI_DGLRPCGetRank()
def set_machine_id(machine_id):
    """Set current machine ID.

    Parameters
    ----------
    machine_id : int
        ID of the machine this process runs on.
    """
    _CAPI_DGLRPCSetMachineID(int(machine_id))
def get_machine_id():
    """Get current machine ID.

    Returns
    -------
    int
        ID of the machine this process runs on.
    """
    return _CAPI_DGLRPCGetMachineID()
def set_num_machines(num_machines):
    """Set number of machines.

    Parameters
    ----------
    num_machines : int
        Number of machines in the cluster.
    """
    _CAPI_DGLRPCSetNumMachines(int(num_machines))
def get_num_machines():
    """Get number of machines.

    Returns
    -------
    int
        Number of machines in the cluster.
    """
    return _CAPI_DGLRPCGetNumMachines()
def set_num_server(num_server):
    """Set the total number of servers."""
    _CAPI_DGLRPCSetNumServer(int(num_server))
def get_num_server():
    """Get the total number of servers."""
    return _CAPI_DGLRPCGetNumServer()
def set_num_client(num_client):
    """Set the total number of clients."""
    _CAPI_DGLRPCSetNumClient(int(num_client))
def get_num_client():
    """Get the total number of clients."""
    return _CAPI_DGLRPCGetNumClient()
def set_num_server_per_machine(num_server):
    """Set the total number of servers per machine.

    Parameters
    ----------
    num_server : int
        Number of server processes on each machine.
    """
    # Coerce to int for consistency with the sibling setters
    # (set_num_server, set_num_client, ...), which all wrap with int().
    _CAPI_DGLRPCSetNumServerPerMachine(int(num_server))
def get_num_server_per_machine():
    """Get the total number of servers per machine."""
    return _CAPI_DGLRPCGetNumServerPerMachine()
def incr_msg_seq():
    """Increment the message sequence number and return the old one.

    Returns
    -------
    int
        The sequence number before the increment.
    """
    return _CAPI_DGLRPCIncrMsgSeq()
def get_msg_seq():
    """Get the current message sequence number.

    Returns
    -------
    int
        The current sequence number.
    """
    return _CAPI_DGLRPCGetMsgSeq()
def set_msg_seq(msg_seq):
    """Set the current message sequence number.

    Parameters
    ----------
    msg_seq : int
        The sequence number to set.
    """
    _CAPI_DGLRPCSetMsgSeq(int(msg_seq))
def register_service(service_id, req_cls, res_cls=None): def register_service(service_id, req_cls, res_cls=None):
"""Register a service to RPC. """Register a service to RPC.
...@@ -332,6 +382,7 @@ def register_service(service_id, req_cls, res_cls=None): ...@@ -332,6 +382,7 @@ def register_service(service_id, req_cls, res_cls=None):
RESPONSE_CLASS_TO_SERVICE_ID[res_cls] = service_id RESPONSE_CLASS_TO_SERVICE_ID[res_cls] = service_id
SERVICE_ID_TO_PROPERTY[service_id] = (req_cls, res_cls) SERVICE_ID_TO_PROPERTY[service_id] = (req_cls, res_cls)
def get_service_property(service_id): def get_service_property(service_id):
"""Get service property. """Get service property.
...@@ -347,6 +398,7 @@ def get_service_property(service_id): ...@@ -347,6 +398,7 @@ def get_service_property(service_id):
""" """
return SERVICE_ID_TO_PROPERTY[service_id] return SERVICE_ID_TO_PROPERTY[service_id]
class Request: class Request:
"""Base request class""" """Base request class"""
...@@ -389,9 +441,14 @@ class Request: ...@@ -389,9 +441,14 @@ class Request:
cls = self.__class__ cls = self.__class__
sid = REQUEST_CLASS_TO_SERVICE_ID.get(cls, None) sid = REQUEST_CLASS_TO_SERVICE_ID.get(cls, None)
if sid is None: if sid is None:
raise DGLError('Request class {} has not been registered as a service.'.format(cls)) raise DGLError(
"Request class {} has not been registered as a service.".format(
cls
)
)
return sid return sid
class Response: class Response:
"""Base response class""" """Base response class"""
...@@ -417,9 +474,14 @@ class Response: ...@@ -417,9 +474,14 @@ class Response:
cls = self.__class__ cls = self.__class__
sid = RESPONSE_CLASS_TO_SERVICE_ID.get(cls, None) sid = RESPONSE_CLASS_TO_SERVICE_ID.get(cls, None)
if sid is None: if sid is None:
raise DGLError('Response class {} has not been registered as a service.'.format(cls)) raise DGLError(
"Response class {} has not been registered as a service.".format(
cls
)
)
return sid return sid
def serialize_to_payload(serializable): def serialize_to_payload(serializable):
"""Serialize an object to payloads. """Serialize an object to payloads.
...@@ -452,11 +514,14 @@ def serialize_to_payload(serializable): ...@@ -452,11 +514,14 @@ def serialize_to_payload(serializable):
data = bytearray(pickle.dumps((nonarray_pos, nonarray_state))) data = bytearray(pickle.dumps((nonarray_pos, nonarray_state)))
return data, array_state return data, array_state
class PlaceHolder:
    """PlaceHolder object for deserialization"""


# Shared singleton instance used as the placeholder sentinel during
# deserialization.
_PLACEHOLDER = PlaceHolder()
def deserialize_from_payload(cls, data, tensors): def deserialize_from_payload(cls, data, tensors):
"""Deserialize and reconstruct the object from payload. """Deserialize and reconstruct the object from payload.
...@@ -496,7 +561,8 @@ def deserialize_from_payload(cls, data, tensors): ...@@ -496,7 +561,8 @@ def deserialize_from_payload(cls, data, tensors):
obj.__setstate__(state) obj.__setstate__(state)
return obj return obj
@register_object('rpc.RPCMessage')
@register_object("rpc.RPCMessage")
class RPCMessage(ObjectBase): class RPCMessage(ObjectBase):
"""Serialized RPC message that can be sent to remote processes. """Serialized RPC message that can be sent to remote processes.
...@@ -519,7 +585,17 @@ class RPCMessage(ObjectBase): ...@@ -519,7 +585,17 @@ class RPCMessage(ObjectBase):
group_id : int group_id : int
The group ID The group ID
""" """
def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors, group_id=0):
def __init__(
self,
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
group_id=0,
):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLRPCCreateRPCMessage, _CAPI_DGLRPCCreateRPCMessage,
int(service_id), int(service_id),
...@@ -528,7 +604,8 @@ class RPCMessage(ObjectBase): ...@@ -528,7 +604,8 @@ class RPCMessage(ObjectBase):
int(server_id), int(server_id),
data, data,
[F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors], [F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors],
int(group_id)) int(group_id),
)
@property @property
def service_id(self): def service_id(self):
...@@ -566,6 +643,7 @@ class RPCMessage(ObjectBase): ...@@ -566,6 +643,7 @@ class RPCMessage(ObjectBase):
"""Get group ID.""" """Get group ID."""
return _CAPI_DGLRPCMessageGetGroupId(self) return _CAPI_DGLRPCMessageGetGroupId(self)
def send_request(target, request): def send_request(target, request):
"""Send one request to the target server. """Send one request to the target server.
...@@ -593,10 +671,18 @@ def send_request(target, request): ...@@ -593,10 +671,18 @@ def send_request(target, request):
client_id = get_rank() client_id = get_rank()
server_id = target server_id = target
data, tensors = serialize_to_payload(request) data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, msg = RPCMessage(
data, tensors, group_id=get_group_id()) service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
group_id=get_group_id(),
)
send_rpc_message(msg, server_id) send_rpc_message(msg, server_id)
def send_request_to_machine(target, request): def send_request_to_machine(target, request):
"""Send one request to the target machine, which will randomly """Send one request to the target machine, which will randomly
select a server node to process this request. select a server node to process this request.
...@@ -620,12 +706,17 @@ def send_request_to_machine(target, request): ...@@ -620,12 +706,17 @@ def send_request_to_machine(target, request):
service_id = request.service_id service_id = request.service_id
msg_seq = incr_msg_seq() msg_seq = incr_msg_seq()
client_id = get_rank() client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(), server_id = random.randint(
(target+1)*get_num_server_per_machine()-1) target * get_num_server_per_machine(),
(target + 1) * get_num_server_per_machine() - 1,
)
data, tensors = serialize_to_payload(request) data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id()) msg = RPCMessage(
service_id, msg_seq, client_id, server_id, data, tensors, get_group_id()
)
send_rpc_message(msg, server_id) send_rpc_message(msg, server_id)
def send_response(target, response, group_id): def send_response(target, response, group_id):
"""Send one response to the target client. """Send one response to the target client.
...@@ -655,9 +746,12 @@ def send_response(target, response, group_id): ...@@ -655,9 +746,12 @@ def send_response(target, response, group_id):
client_id = target client_id = target
server_id = get_rank() server_id = get_rank()
data, tensors = serialize_to_payload(response) data, tensors = serialize_to_payload(response)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, group_id) msg = RPCMessage(
service_id, msg_seq, client_id, server_id, data, tensors, group_id
)
send_rpc_message(msg, get_client(client_id, group_id)) send_rpc_message(msg, get_client(client_id, group_id))
def recv_request(timeout=0): def recv_request(timeout=0):
"""Receive one request. """Receive one request.
...@@ -690,14 +784,19 @@ def recv_request(timeout=0): ...@@ -690,14 +784,19 @@ def recv_request(timeout=0):
set_msg_seq(msg.msg_seq) set_msg_seq(msg.msg_seq)
req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id] req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id]
if req_cls is None: if req_cls is None:
raise DGLError('Got request message from service ID {}, ' raise DGLError(
'but no request class is registered.'.format(msg.service_id)) "Got request message from service ID {}, "
"but no request class is registered.".format(msg.service_id)
)
req = deserialize_from_payload(req_cls, msg.data, msg.tensors) req = deserialize_from_payload(req_cls, msg.data, msg.tensors)
if msg.server_id != get_rank(): if msg.server_id != get_rank():
raise DGLError('Got request sent to server {}, ' raise DGLError(
'different from my rank {}!'.format(msg.server_id, get_rank())) "Got request sent to server {}, "
"different from my rank {}!".format(msg.server_id, get_rank())
)
return req, msg.client_id, msg.group_id return req, msg.client_id, msg.group_id
def recv_response(timeout=0): def recv_response(timeout=0):
"""Receive one response. """Receive one response.
...@@ -725,17 +824,24 @@ def recv_response(timeout=0): ...@@ -725,17 +824,24 @@ def recv_response(timeout=0):
return None return None
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id] _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None: if res_cls is None:
raise DGLError('Got response message from service ID {}, ' raise DGLError(
'but no response class is registered.'.format(msg.service_id)) "Got response message from service ID {}, "
"but no response class is registered.".format(msg.service_id)
)
res = deserialize_from_payload(res_cls, msg.data, msg.tensors) res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != get_rank() and get_rank() != -1: if msg.client_id != get_rank() and get_rank() != -1:
raise DGLError('Got response of request sent by client {}, ' raise DGLError(
'different from my rank {}!'.format(msg.client_id, get_rank())) "Got response of request sent by client {}, "
"different from my rank {}!".format(msg.client_id, get_rank())
)
if msg.group_id != get_group_id(): if msg.group_id != get_group_id():
raise DGLError("Got response of request sent by group {}, " raise DGLError(
"different from my group {}!".format(msg.group_id, get_group_id())) "Got response of request sent by group {}, "
"different from my group {}!".format(msg.group_id, get_group_id())
)
return res return res
def remote_call(target_and_requests, timeout=0): def remote_call(target_and_requests, timeout=0):
"""Invoke registered services on remote servers and collect responses. """Invoke registered services on remote servers and collect responses.
...@@ -771,10 +877,20 @@ def remote_call(target_and_requests, timeout=0): ...@@ -771,10 +877,20 @@ def remote_call(target_and_requests, timeout=0):
service_id = request.service_id service_id = request.service_id
msg_seq = incr_msg_seq() msg_seq = incr_msg_seq()
client_id = get_rank() client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(), server_id = random.randint(
(target+1)*get_num_server_per_machine()-1) target * get_num_server_per_machine(),
(target + 1) * get_num_server_per_machine() - 1,
)
data, tensors = serialize_to_payload(request) data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id()) msg = RPCMessage(
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
get_group_id(),
)
send_rpc_message(msg, server_id) send_rpc_message(msg, server_id)
# check if has response # check if has response
res_cls = get_service_property(service_id)[1] res_cls = get_service_property(service_id)[1]
...@@ -786,22 +902,28 @@ def remote_call(target_and_requests, timeout=0): ...@@ -786,22 +902,28 @@ def remote_call(target_and_requests, timeout=0):
msg = recv_rpc_message(timeout) msg = recv_rpc_message(timeout)
if msg is None: if msg is None:
raise DGLError( raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds") f"Timed out for receiving message within {timeout} milliseconds"
)
num_res -= 1 num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id] _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None: if res_cls is None:
raise DGLError('Got response message from service ID {}, ' raise DGLError(
'but no response class is registered.'.format(msg.service_id)) "Got response message from service ID {}, "
"but no response class is registered.".format(msg.service_id)
)
res = deserialize_from_payload(res_cls, msg.data, msg.tensors) res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != myrank: if msg.client_id != myrank:
raise DGLError('Got reponse of request sent by client {}, ' raise DGLError(
'different from my rank {}!'.format(msg.client_id, myrank)) "Got reponse of request sent by client {}, "
"different from my rank {}!".format(msg.client_id, myrank)
)
# set response # set response
all_res[msgseq2pos[msg.msg_seq]] = res all_res[msgseq2pos[msg.msg_seq]] = res
return all_res return all_res
def send_requests_to_machine(target_and_requests): def send_requests_to_machine(target_and_requests):
""" Send requests to the remote machines. """Send requests to the remote machines.
This operation isn't block. It returns immediately once it sends all requests. This operation isn't block. It returns immediately once it sends all requests.
...@@ -824,10 +946,20 @@ def send_requests_to_machine(target_and_requests): ...@@ -824,10 +946,20 @@ def send_requests_to_machine(target_and_requests):
msg_seq = incr_msg_seq() msg_seq = incr_msg_seq()
client_id = get_rank() client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(), server_id = random.randint(
(target+1)*get_num_server_per_machine()-1) target * get_num_server_per_machine(),
(target + 1) * get_num_server_per_machine() - 1,
)
data, tensors = serialize_to_payload(request) data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors, get_group_id()) msg = RPCMessage(
service_id,
msg_seq,
client_id,
server_id,
data,
tensors,
get_group_id(),
)
send_rpc_message(msg, server_id) send_rpc_message(msg, server_id)
# check if has response # check if has response
res_cls = get_service_property(service_id)[1] res_cls = get_service_property(service_id)[1]
...@@ -835,8 +967,9 @@ def send_requests_to_machine(target_and_requests): ...@@ -835,8 +967,9 @@ def send_requests_to_machine(target_and_requests):
msgseq2pos[msg_seq] = pos msgseq2pos[msg_seq] = pos
return msgseq2pos return msgseq2pos
def recv_responses(msgseq2pos, timeout=0): def recv_responses(msgseq2pos, timeout=0):
""" Receive responses """Receive responses
It returns the responses in the same order as the requests. The order of requests It returns the responses in the same order as the requests. The order of requests
are stored in msgseq2pos. are stored in msgseq2pos.
...@@ -866,20 +999,26 @@ def recv_responses(msgseq2pos, timeout=0): ...@@ -866,20 +999,26 @@ def recv_responses(msgseq2pos, timeout=0):
msg = recv_rpc_message(timeout) msg = recv_rpc_message(timeout)
if msg is None: if msg is None:
raise DGLError( raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds") f"Timed out for receiving message within {timeout} milliseconds"
)
num_res -= 1 num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id] _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None: if res_cls is None:
raise DGLError('Got response message from service ID {}, ' raise DGLError(
'but no response class is registered.'.format(msg.service_id)) "Got response message from service ID {}, "
"but no response class is registered.".format(msg.service_id)
)
res = deserialize_from_payload(res_cls, msg.data, msg.tensors) res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != myrank: if msg.client_id != myrank:
raise DGLError('Got reponse of request sent by client {}, ' raise DGLError(
'different from my rank {}!'.format(msg.client_id, myrank)) "Got reponse of request sent by client {}, "
"different from my rank {}!".format(msg.client_id, myrank)
)
# set response # set response
all_res[msgseq2pos[msg.msg_seq]] = res all_res[msgseq2pos[msg.msg_seq]] = res
return all_res return all_res
def remote_call_to_machine(target_and_requests, timeout=0): def remote_call_to_machine(target_and_requests, timeout=0):
"""Invoke registered services on remote machine """Invoke registered services on remote machine
(which will ramdom select a server to process the request) and collect responses. (which will ramdom select a server to process the request) and collect responses.
...@@ -910,6 +1049,7 @@ def remote_call_to_machine(target_and_requests, timeout=0): ...@@ -910,6 +1049,7 @@ def remote_call_to_machine(target_and_requests, timeout=0):
msgseq2pos = send_requests_to_machine(target_and_requests) msgseq2pos = send_requests_to_machine(target_and_requests)
return recv_responses(msgseq2pos, timeout) return recv_responses(msgseq2pos, timeout)
def send_rpc_message(msg, target): def send_rpc_message(msg, target):
"""Send one message to the target server. """Send one message to the target server.
...@@ -936,6 +1076,7 @@ def send_rpc_message(msg, target): ...@@ -936,6 +1076,7 @@ def send_rpc_message(msg, target):
""" """
_CAPI_DGLRPCSendRPCMessage(msg, int(target)) _CAPI_DGLRPCSendRPCMessage(msg, int(target))
def recv_rpc_message(timeout=0): def recv_rpc_message(timeout=0):
"""Receive one message. """Receive one message.
...@@ -960,23 +1101,34 @@ def recv_rpc_message(timeout=0): ...@@ -960,23 +1101,34 @@ def recv_rpc_message(timeout=0):
status = _CAPI_DGLRPCRecvRPCMessage(timeout, msg) status = _CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg if status == 0 else None return msg if status == 0 else None
def client_barrier():
    """Barrier all client processes.

    Sends a ``ClientBarrierRequest`` to server 0 and blocks until the
    barrier acknowledgement response arrives.

    Raises
    ------
    DGLError
        If the received response does not carry the expected barrier
        acknowledgement message.
    """
    req = ClientBarrierRequest()
    send_request(0, req)
    res = recv_response()
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently ignore a protocol error here.
    if res.msg != "barrier":
        raise DGLError(
            "Expected barrier acknowledgement, got message: {}".format(res.msg)
        )
def finalize_server():
    """Finalize resources of current server"""
    # Release the sending side first, then the receiving side.
    finalize_sender()
    finalize_receiver()
    rank = get_rank()
    print("Server (%d) shutdown." % rank)
def fast_pull(name, id_tensor, part_id, service_id,
machine_count, group_count, machine_id, def fast_pull(
client_id, local_data, policy): name,
id_tensor,
part_id,
service_id,
machine_count,
group_count,
machine_id,
client_id,
local_data,
policy,
):
"""Fast-pull api used by kvstore. """Fast-pull api used by kvstore.
Parameters Parameters
...@@ -1004,39 +1156,45 @@ def fast_pull(name, id_tensor, part_id, service_id, ...@@ -1004,39 +1156,45 @@ def fast_pull(name, id_tensor, part_id, service_id,
""" """
msg_seq = incr_msg_seq() msg_seq = incr_msg_seq()
pickle_data = bytearray(pickle.dumps(([0], [name]))) pickle_data = bytearray(pickle.dumps(([0], [name])))
global_id = _CAPI_DGLRPCGetGlobalIDFromLocalPartition(F.zerocopy_to_dgl_ndarray(id_tensor), global_id = _CAPI_DGLRPCGetGlobalIDFromLocalPartition(
F.zerocopy_to_dgl_ndarray(part_id), F.zerocopy_to_dgl_ndarray(id_tensor),
machine_id) F.zerocopy_to_dgl_ndarray(part_id),
machine_id,
)
global_id = F.zerocopy_from_dgl_ndarray(global_id) global_id = F.zerocopy_from_dgl_ndarray(global_id)
g2l_id = policy.to_local(global_id) g2l_id = policy.to_local(global_id)
res_tensor = _CAPI_DGLRPCFastPull(name, res_tensor = _CAPI_DGLRPCFastPull(
int(machine_id), name,
int(machine_count), int(machine_id),
int(group_count), int(machine_count),
int(client_id), int(group_count),
int(service_id), int(client_id),
int(msg_seq), int(service_id),
pickle_data, int(msg_seq),
F.zerocopy_to_dgl_ndarray(id_tensor), pickle_data,
F.zerocopy_to_dgl_ndarray(part_id), F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(g2l_id), F.zerocopy_to_dgl_ndarray(part_id),
F.zerocopy_to_dgl_ndarray(local_data)) F.zerocopy_to_dgl_ndarray(g2l_id),
F.zerocopy_to_dgl_ndarray(local_data),
)
return F.zerocopy_from_dgl_ndarray(res_tensor) return F.zerocopy_from_dgl_ndarray(res_tensor)
def register_sig_handler():
    """Register for handling signal event."""
    # Delegates signal-handler installation to the C RPC backend.
    _CAPI_DGLRPCHandleSignal()
def copy_data_to_shared_memory(dst, source):
    """Copy tensor data to shared-memory tensor"""
    dst_ndarray = F.zerocopy_to_dgl_ndarray(dst)
    src_ndarray = F.zerocopy_to_dgl_ndarray(source)
    # copyfrom writes the source bytes into dst's underlying buffer in place.
    dst_ndarray.copyfrom(src_ndarray)
############### Some basic services will be defined here ############# ############### Some basic services will be defined here #############
# Service ID of the basic client-registration service.
CLIENT_REGISTER = 22451
class ClientRegisterRequest(Request): class ClientRegisterRequest(Request):
"""This request will send client's ip to server. """This request will send client's ip to server.
...@@ -1045,6 +1203,7 @@ class ClientRegisterRequest(Request): ...@@ -1045,6 +1203,7 @@ class ClientRegisterRequest(Request):
ip_addr : str ip_addr : str
client's IP address client's IP address
""" """
def __init__(self, ip_addr): def __init__(self, ip_addr):
self.ip_addr = ip_addr self.ip_addr = ip_addr
...@@ -1055,7 +1214,8 @@ class ClientRegisterRequest(Request): ...@@ -1055,7 +1214,8 @@ class ClientRegisterRequest(Request):
self.ip_addr = state self.ip_addr = state
def process_request(self, server_state): def process_request(self, server_state):
return None # do nothing return None # do nothing
class ClientRegisterResponse(Response): class ClientRegisterResponse(Response):
"""This response will send assigned ID to client. """This response will send assigned ID to client.
...@@ -1065,6 +1225,7 @@ class ClientRegisterResponse(Response): ...@@ -1065,6 +1225,7 @@ class ClientRegisterResponse(Response):
ID : int ID : int
client's ID client's ID
""" """
def __init__(self, client_id): def __init__(self, client_id):
self.client_id = client_id self.client_id = client_id
...@@ -1077,6 +1238,7 @@ class ClientRegisterResponse(Response): ...@@ -1077,6 +1238,7 @@ class ClientRegisterResponse(Response):
# Service ID of the basic server shut-down service.
SHUT_DOWN_SERVER = 22452
class ShutDownRequest(Request): class ShutDownRequest(Request):
"""Client send this request to shut-down a server. """Client send this request to shut-down a server.
...@@ -1087,6 +1249,7 @@ class ShutDownRequest(Request): ...@@ -1087,6 +1249,7 @@ class ShutDownRequest(Request):
client_id : int client_id : int
client's ID client's ID
""" """
def __init__(self, client_id, force_shutdown_server=False): def __init__(self, client_id, force_shutdown_server=False):
self.client_id = client_id self.client_id = client_id
self.force_shutdown_server = force_shutdown_server self.force_shutdown_server = force_shutdown_server
...@@ -1104,8 +1267,10 @@ class ShutDownRequest(Request): ...@@ -1104,8 +1267,10 @@ class ShutDownRequest(Request):
finalize_server() finalize_server()
return SERVER_EXIT return SERVER_EXIT
# Service ID of the basic service that reports the total number of clients.
GET_NUM_CLIENT = 22453
class GetNumberClientsResponse(Response): class GetNumberClientsResponse(Response):
"""This reponse will send total number of clients. """This reponse will send total number of clients.
...@@ -1114,6 +1279,7 @@ class GetNumberClientsResponse(Response): ...@@ -1114,6 +1279,7 @@ class GetNumberClientsResponse(Response):
num_client : int num_client : int
total number of clients total number of clients
""" """
def __init__(self, num_client): def __init__(self, num_client):
self.num_client = num_client self.num_client = num_client
...@@ -1123,6 +1289,7 @@ class GetNumberClientsResponse(Response): ...@@ -1123,6 +1289,7 @@ class GetNumberClientsResponse(Response):
def __setstate__(self, state): def __setstate__(self, state):
self.num_client = state self.num_client = state
class GetNumberClientsRequest(Request): class GetNumberClientsRequest(Request):
"""Client send this request to get the total number of client. """Client send this request to get the total number of client.
...@@ -1131,6 +1298,7 @@ class GetNumberClientsRequest(Request): ...@@ -1131,6 +1298,7 @@ class GetNumberClientsRequest(Request):
client_id : int client_id : int
client's ID client's ID
""" """
def __init__(self, client_id): def __init__(self, client_id):
self.client_id = client_id self.client_id = client_id
...@@ -1144,8 +1312,10 @@ class GetNumberClientsRequest(Request): ...@@ -1144,8 +1312,10 @@ class GetNumberClientsRequest(Request):
res = GetNumberClientsResponse(get_num_client()) res = GetNumberClientsResponse(get_num_client())
return res return res
# Service ID of the basic client-barrier service.
CLIENT_BARRIER = 22454
class ClientBarrierResponse(Response): class ClientBarrierResponse(Response):
"""Send the barrier confirmation to client """Send the barrier confirmation to client
...@@ -1154,7 +1324,8 @@ class ClientBarrierResponse(Response): ...@@ -1154,7 +1324,8 @@ class ClientBarrierResponse(Response):
msg : str msg : str
string msg string msg
""" """
def __init__(self, msg='barrier'):
def __init__(self, msg="barrier"):
self.msg = msg self.msg = msg
def __getstate__(self): def __getstate__(self):
...@@ -1163,6 +1334,7 @@ class ClientBarrierResponse(Response): ...@@ -1163,6 +1334,7 @@ class ClientBarrierResponse(Response):
def __setstate__(self, state): def __setstate__(self, state):
self.msg = state self.msg = state
class ClientBarrierRequest(Request): class ClientBarrierRequest(Request):
"""Send the barrier information to server """Send the barrier information to server
...@@ -1171,7 +1343,8 @@ class ClientBarrierRequest(Request): ...@@ -1171,7 +1343,8 @@ class ClientBarrierRequest(Request):
msg : str msg : str
string msg string msg
""" """
def __init__(self, msg='barrier'):
def __init__(self, msg="barrier"):
self.msg = msg self.msg = msg
self.group_id = get_group_id() self.group_id = get_group_id()
...@@ -1182,7 +1355,9 @@ class ClientBarrierRequest(Request): ...@@ -1182,7 +1355,9 @@ class ClientBarrierRequest(Request):
self.msg, self.group_id = state self.msg, self.group_id = state
def process_request(self, server_state): def process_request(self, server_state):
_CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount(self.group_id)+1, self.group_id) _CAPI_DGLRPCSetBarrierCount(
_CAPI_DGLRPCGetBarrierCount(self.group_id) + 1, self.group_id
)
if _CAPI_DGLRPCGetBarrierCount(self.group_id) == get_num_client(): if _CAPI_DGLRPCGetBarrierCount(self.group_id) == get_num_client():
_CAPI_DGLRPCSetBarrierCount(0, self.group_id) _CAPI_DGLRPCSetBarrierCount(0, self.group_id)
res_list = [] res_list = []
...@@ -1191,6 +1366,7 @@ class ClientBarrierRequest(Request): ...@@ -1191,6 +1366,7 @@ class ClientBarrierRequest(Request):
return res_list return res_list
return None return None
def set_group_id(group_id): def set_group_id(group_id):
"""Set current group ID """Set current group ID
...@@ -1201,6 +1377,7 @@ def set_group_id(group_id): ...@@ -1201,6 +1377,7 @@ def set_group_id(group_id):
""" """
_CAPI_DGLRPCSetGroupID(int(group_id)) _CAPI_DGLRPCSetGroupID(int(group_id))
def get_group_id(): def get_group_id():
"""Get current group ID """Get current group ID
...@@ -1211,6 +1388,7 @@ def get_group_id(): ...@@ -1211,6 +1388,7 @@ def get_group_id():
""" """
return _CAPI_DGLRPCGetGroupID() return _CAPI_DGLRPCGetGroupID()
def register_client(client_id, group_id): def register_client(client_id, group_id):
"""Register client """Register client
...@@ -1221,6 +1399,7 @@ def register_client(client_id, group_id): ...@@ -1221,6 +1399,7 @@ def register_client(client_id, group_id):
""" """
return _CAPI_DGLRPCRegisterClient(int(client_id), int(group_id)) return _CAPI_DGLRPCRegisterClient(int(client_id), int(group_id))
def get_client(client_id, group_id): def get_client(client_id, group_id):
"""Get global client ID """Get global client ID
...@@ -1238,6 +1417,7 @@ def get_client(client_id, group_id): ...@@ -1238,6 +1417,7 @@ def get_client(client_id, group_id):
""" """
return _CAPI_DGLRPCGetClient(int(client_id), int(group_id)) return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))
class DistConnectError(DGLError): class DistConnectError(DGLError):
"""Exception raised for errors if fail to connect peer. """Exception raised for errors if fail to connect peer.
...@@ -1247,12 +1427,16 @@ class DistConnectError(DGLError): ...@@ -1247,12 +1427,16 @@ class DistConnectError(DGLError):
reference for KVServer reference for KVServer
""" """
def __init__(self, max_try_times, ip='', port=''): def __init__(self, max_try_times, ip="", port=""):
peer_str = "peer[{}:{}]".format(ip, port) if ip != '' else "peer" peer_str = "peer[{}:{}]".format(ip, port) if ip != "" else "peer"
self.message = "Failed to build conncetion with {} after {} retries. " \ self.message = (
"Please check network availability or increase max try " \ "Failed to build conncetion with {} after {} retries. "
"times via 'DGL_DIST_MAX_TRY_TIMES'.".format( "Please check network availability or increase max try "
peer_str, max_try_times) "times via 'DGL_DIST_MAX_TRY_TIMES'.".format(
peer_str, max_try_times
)
)
super().__init__(self.message) super().__init__(self.message)
# Load the C API symbols (_CAPI_DGLRPC*) used throughout this module.
_init_api("dgl.distributed.rpc")
"""Functions used by client.""" """Functions used by client."""
import os
import socket
import atexit import atexit
import logging import logging
import os
import socket
import time import time
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
if os.name != 'nt': if os.name != "nt":
import fcntl import fcntl
import struct import struct
def local_ip4_addr_list(): def local_ip4_addr_list():
"""Return a set of IPv4 address """Return a set of IPv4 address
...@@ -20,21 +21,25 @@ def local_ip4_addr_list(): ...@@ -20,21 +21,25 @@ def local_ip4_addr_list():
`logging.getLogger("dgl-distributed-socket").setLevel(logging.WARNING+1)` `logging.getLogger("dgl-distributed-socket").setLevel(logging.WARNING+1)`
to disable the warning here to disable the warning here
""" """
assert os.name != 'nt', 'Do not support Windows rpc yet.' assert os.name != "nt", "Do not support Windows rpc yet."
nic = set() nic = set()
logger = logging.getLogger("dgl-distributed-socket") logger = logging.getLogger("dgl-distributed-socket")
for if_nidx in socket.if_nameindex(): for if_nidx in socket.if_nameindex():
name = if_nidx[1] name = if_nidx[1]
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
ip_of_ni = fcntl.ioctl(sock.fileno(), ip_of_ni = fcntl.ioctl(
0x8915, # SIOCGIFADDR sock.fileno(),
struct.pack('256s', name[:15].encode("UTF-8"))) 0x8915, # SIOCGIFADDR
struct.pack("256s", name[:15].encode("UTF-8")),
)
except OSError as e: except OSError as e:
if e.errno == 99: # EADDRNOTAVAIL if e.errno == 99: # EADDRNOTAVAIL
logger.warning( logger.warning(
"Warning! Interface: %s \n" "Warning! Interface: %s \n"
"IP address not available for interface.", name) "IP address not available for interface.",
name,
)
continue continue
raise e raise e
...@@ -42,6 +47,7 @@ def local_ip4_addr_list(): ...@@ -42,6 +47,7 @@ def local_ip4_addr_list():
nic.add(ip_addr) nic.add(ip_addr)
return nic return nic
def get_local_machine_id(server_namebook): def get_local_machine_id(server_namebook):
"""Given server_namebook, find local machine ID """Given server_namebook, find local machine ID
...@@ -76,6 +82,7 @@ def get_local_machine_id(server_namebook): ...@@ -76,6 +82,7 @@ def get_local_machine_id(server_namebook):
break break
return res return res
def get_local_usable_addr(probe_addr): def get_local_usable_addr(probe_addr):
"""Get local usable IP and port """Get local usable IP and port
...@@ -90,7 +97,7 @@ def get_local_usable_addr(probe_addr): ...@@ -90,7 +97,7 @@ def get_local_usable_addr(probe_addr):
sock.connect((probe_addr, 1)) sock.connect((probe_addr, 1))
ip_addr = sock.getsockname()[0] ip_addr = sock.getsockname()[0]
except ValueError: except ValueError:
ip_addr = '127.0.0.1' ip_addr = "127.0.0.1"
finally: finally:
sock.close() sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
...@@ -99,11 +106,16 @@ def get_local_usable_addr(probe_addr): ...@@ -99,11 +106,16 @@ def get_local_usable_addr(probe_addr):
port = sock.getsockname()[1] port = sock.getsockname()[1]
sock.close() sock.close()
return ip_addr + ':' + str(port) return ip_addr + ":" + str(port)
def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, def connect_to_server(
net_type='socket', group_id=0): ip_config,
num_servers,
max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
group_id=0,
):
"""Connect this client to server. """Connect this client to server.
Parameters Parameters
...@@ -127,23 +139,30 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, ...@@ -127,23 +139,30 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
------ ------
ConnectionError : If anything wrong with the connection. ConnectionError : If anything wrong with the connection.
""" """
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers assert num_servers > 0, (
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size "num_servers (%d) must be a positive number." % num_servers
assert net_type in ('socket', 'tensorpipe'), \ )
'net_type (%s) can only be \'socket\' or \'tensorpipe\'.' % net_type assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size
)
assert net_type in ("socket", "tensorpipe"), (
"net_type (%s) can only be 'socket' or 'tensorpipe'." % net_type
)
# Register some basic service # Register some basic service
rpc.register_service(rpc.CLIENT_REGISTER, rpc.register_service(
rpc.ClientRegisterRequest, rpc.CLIENT_REGISTER,
rpc.ClientRegisterResponse) rpc.ClientRegisterRequest,
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ClientRegisterResponse,
rpc.ShutDownRequest, )
None) rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.register_service(rpc.GET_NUM_CLIENT, rpc.register_service(
rpc.GetNumberClientsRequest, rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsResponse) rpc.GetNumberClientsRequest,
rpc.register_service(rpc.CLIENT_BARRIER, rpc.GetNumberClientsResponse,
rpc.ClientBarrierRequest, )
rpc.ClientBarrierResponse) rpc.register_service(
rpc.CLIENT_BARRIER, rpc.ClientBarrierRequest, rpc.ClientBarrierResponse
)
rpc.register_sig_handler() rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers) server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook) num_servers = len(server_namebook)
...@@ -157,7 +176,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, ...@@ -157,7 +176,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
if server_info[0] > max_machine_id: if server_info[0] > max_machine_id:
max_machine_id = server_info[0] max_machine_id = server_info[0]
rpc.set_num_server_per_machine(group_count[0]) rpc.set_num_server_per_machine(group_count[0])
num_machines = max_machine_id+1 num_machines = max_machine_id + 1
rpc.set_num_machines(num_machines) rpc.set_num_machines(num_machines)
machine_id = get_local_machine_id(server_namebook) machine_id = get_local_machine_id(server_namebook)
rpc.set_machine_id(machine_id) rpc.set_machine_id(machine_id)
...@@ -165,7 +184,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, ...@@ -165,7 +184,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
rpc.create_sender(max_queue_size, net_type) rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type) rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes # Get connected with all server nodes
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 1024)) max_try_times = int(os.environ.get("DGL_DIST_MAX_TRY_TIMES", 1024))
for server_id, addr in server_namebook.items(): for server_id, addr in server_namebook.items():
server_ip = addr[1] server_ip = addr[1]
server_port = addr[2] server_port = addr[2]
...@@ -173,39 +192,50 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, ...@@ -173,39 +192,50 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
while not rpc.connect_receiver(server_ip, server_port, server_id): while not rpc.connect_receiver(server_ip, server_port, server_id):
try_times += 1 try_times += 1
if try_times % 200 == 0: if try_times % 200 == 0:
print("Client is trying to connect server receiver: {}:{}".format( print(
server_ip, server_port)) "Client is trying to connect server receiver: {}:{}".format(
server_ip, server_port
)
)
if try_times >= max_try_times: if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, server_ip, server_port) raise rpc.DistConnectError(
max_try_times, server_ip, server_port
)
time.sleep(3) time.sleep(3)
if not rpc.connect_receiver_finalize(max_try_times): if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times) raise rpc.DistConnectError(max_try_times)
# Get local usable IP address and port # Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip) ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':') client_ip, client_port = ip_addr.split(":")
# Register client on server # Register client on server
register_req = rpc.ClientRegisterRequest(ip_addr) register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers): for server_id in range(num_servers):
rpc.send_request(server_id, register_req) rpc.send_request(server_id, register_req)
# wait server connect back # wait server connect back
rpc.wait_for_senders(client_ip, client_port, num_servers, rpc.wait_for_senders(
blocking=net_type == 'socket') client_ip, client_port, num_servers, blocking=net_type == "socket"
print("Client [{}] waits on {}:{}".format( )
os.getpid(), client_ip, client_port)) print(
"Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port)
)
# recv client ID from server # recv client ID from server
res = rpc.recv_response() res = rpc.recv_response()
rpc.set_rank(res.client_id) rpc.set_rank(res.client_id)
print("Machine (%d) group (%d) client (%d) connect to server successfuly!" \ print(
% (machine_id, group_id, rpc.get_rank())) "Machine (%d) group (%d) client (%d) connect to server successfuly!"
% (machine_id, group_id, rpc.get_rank())
)
# get total number of client # get total number of client
get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank()) get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())
rpc.send_request(0, get_client_num_req) rpc.send_request(0, get_client_num_req)
res = rpc.recv_response() res = rpc.recv_response()
rpc.set_num_client(res.num_client) rpc.set_num_client(res.num_client)
from .dist_context import exit_client, set_initialized from .dist_context import exit_client, set_initialized
atexit.register(exit_client) atexit.register(exit_client)
set_initialized(True) set_initialized(True)
def shutdown_servers(ip_config, num_servers): def shutdown_servers(ip_config, num_servers):
"""Issue commands to remote servers to shut them down. """Issue commands to remote servers to shut them down.
...@@ -229,13 +259,11 @@ def shutdown_servers(ip_config, num_servers): ...@@ -229,13 +259,11 @@ def shutdown_servers(ip_config, num_servers):
------ ------
ConnectionError : If anything wrong with the connection. ConnectionError : If anything wrong with the connection.
""" """
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.ShutDownRequest,
None)
rpc.register_sig_handler() rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers) server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook) num_servers = len(server_namebook)
rpc.create_sender(MAX_QUEUE_SIZE, 'tensorpipe') rpc.create_sender(MAX_QUEUE_SIZE, "tensorpipe")
# Get connected with all server nodes # Get connected with all server nodes
for server_id, addr in server_namebook.items(): for server_id, addr in server_namebook.items():
server_ip = addr[1] server_ip = addr[1]
......
"""Functions used by server.""" """Functions used by server."""
import time
import os import os
import time
from ..base import DGLError from ..base import DGLError
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'): def start_server(
server_id,
ip_config,
num_servers,
num_clients,
server_state,
max_queue_size=MAX_QUEUE_SIZE,
net_type="socket",
):
"""Start DGL server, which will be shared with all the rpc services. """Start DGL server, which will be shared with all the rpc services.
This is a blocking function -- it returns only when the server shutdown. This is a blocking function -- it returns only when the server shutdown.
...@@ -34,33 +43,47 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -34,33 +43,47 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
net_type : str net_type : str
Networking type. Current options are: ``'socket'`` or ``'tensorpipe'``. Networking type. Current options are: ``'socket'`` or ``'tensorpipe'``.
""" """
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id assert server_id >= 0, (
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers "server_id (%d) cannot be a negative number." % server_id
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients )
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size assert num_servers > 0, (
assert net_type in ('socket', 'tensorpipe'), \ "num_servers (%d) must be a positive number." % num_servers
'net_type (%s) can only be \'socket\' or \'tensorpipe\'' % net_type )
assert num_clients >= 0, (
"num_client (%d) cannot be a negative number." % num_clients
)
assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size
)
assert net_type in ("socket", "tensorpipe"), (
"net_type (%s) can only be 'socket' or 'tensorpipe'" % net_type
)
if server_state.keep_alive: if server_state.keep_alive:
assert net_type == 'tensorpipe', \ assert (
"net_type can only be 'tensorpipe' if 'keep_alive' is enabled." net_type == "tensorpipe"
print("As configured, this server will keep alive for multiple" ), "net_type can only be 'tensorpipe' if 'keep_alive' is enabled."
" client groups until force shutdown request is received." print(
" [WARNING] This feature is experimental and not fully tested.") "As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received."
" [WARNING] This feature is experimental and not fully tested."
)
# Register signal handler. # Register signal handler.
rpc.register_sig_handler() rpc.register_sig_handler()
# Register some basic services # Register some basic services
rpc.register_service(rpc.CLIENT_REGISTER, rpc.register_service(
rpc.ClientRegisterRequest, rpc.CLIENT_REGISTER,
rpc.ClientRegisterResponse) rpc.ClientRegisterRequest,
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ClientRegisterResponse,
rpc.ShutDownRequest, )
None) rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.register_service(rpc.GET_NUM_CLIENT, rpc.register_service(
rpc.GetNumberClientsRequest, rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsResponse) rpc.GetNumberClientsRequest,
rpc.register_service(rpc.CLIENT_BARRIER, rpc.GetNumberClientsResponse,
rpc.ClientBarrierRequest, )
rpc.ClientBarrierResponse) rpc.register_service(
rpc.CLIENT_BARRIER, rpc.ClientBarrierRequest, rpc.ClientBarrierResponse
)
rpc.set_rank(server_id) rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config, num_servers) server_namebook = rpc.read_ip_config(ip_config, num_servers)
machine_id = server_namebook[server_id][0] machine_id = server_namebook[server_id][0]
...@@ -73,9 +96,11 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -73,9 +96,11 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# Once all the senders connect to server, server will not # Once all the senders connect to server, server will not
# accept new sender's connection # accept new sender's connection
print( print(
"Server is waiting for connections on [{}:{}]...".format(ip_addr, port)) "Server is waiting for connections on [{}:{}]...".format(ip_addr, port)
rpc.wait_for_senders(ip_addr, port, num_clients, )
blocking=net_type == 'socket') rpc.wait_for_senders(
ip_addr, port, num_clients, blocking=net_type == "socket"
)
rpc.set_num_client(num_clients) rpc.set_num_client(num_clients)
recv_clients = {} recv_clients = {}
while True: while True:
...@@ -89,18 +114,25 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -89,18 +114,25 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# a new client group is ready # a new client group is ready
ips.sort() ips.sort()
client_namebook = dict(enumerate(ips)) client_namebook = dict(enumerate(ips))
time.sleep(3) # wait for clients' receivers ready time.sleep(3) # wait for clients' receivers ready
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 120)) max_try_times = int(os.environ.get("DGL_DIST_MAX_TRY_TIMES", 120))
for client_id, addr in client_namebook.items(): for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':') client_ip, client_port = addr.split(":")
try_times = 0 try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id): while not rpc.connect_receiver(
client_ip, client_port, client_id, group_id
):
try_times += 1 try_times += 1
if try_times % 200 == 0: if try_times % 200 == 0:
print("Server~{} is trying to connect client receiver: {}:{}".format( print(
server_id, client_ip, client_port)) "Server~{} is trying to connect client receiver: {}:{}".format(
server_id, client_ip, client_port
)
)
if try_times >= max_try_times: if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, client_ip, client_port) raise rpc.DistConnectError(
max_try_times, client_ip, client_port
)
time.sleep(1) time.sleep(1)
if not rpc.connect_receiver_finalize(max_try_times): if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times) raise rpc.DistConnectError(max_try_times)
...@@ -130,7 +162,11 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -130,7 +162,11 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
print("Server is exiting...") print("Server is exiting...")
return return
elif res == SERVER_KEEP_ALIVE: elif res == SERVER_KEEP_ALIVE:
print("Server keeps alive while client group~{} is exiting...".format(group_id)) print(
"Server keeps alive while client group~{} is exiting...".format(
group_id
)
)
else: else:
raise DGLError("Unexpected response: {}".format(res)) raise DGLError("Unexpected response: {}".format(res))
else: else:
......
...@@ -77,4 +77,5 @@ class ServerState: ...@@ -77,4 +77,5 @@ class ServerState:
"""Flag of whether keep alive""" """Flag of whether keep alive"""
return self._keep_alive return self._keep_alive
_init_api("dgl.distributed.server_state") _init_api("dgl.distributed.server_state")
"""Define utility functions for shared memory.""" """Define utility functions for shared memory."""
from .. import backend as F from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
from .. import ndarray as nd from .. import ndarray as nd
from .._ffi.ndarray import empty_shared_mem
DTYPE_DICT = F.data_type_dict DTYPE_DICT = F.data_type_dict
DTYPE_DICT = {DTYPE_DICT[key]:key for key in DTYPE_DICT} DTYPE_DICT = {DTYPE_DICT[key]: key for key in DTYPE_DICT}
def _get_ndata_path(graph_name, ndata_name): def _get_ndata_path(graph_name, ndata_name):
return "/" + graph_name + "_node_" + ndata_name return "/" + graph_name + "_node_" + ndata_name
def _get_edata_path(graph_name, edata_name): def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name return "/" + graph_name + "_edge_" + edata_name
def _to_shared_mem(arr, name): def _to_shared_mem(arr, name):
dlpack = F.zerocopy_to_dlpack(arr) dlpack = F.zerocopy_to_dlpack(arr)
dgl_tensor = nd.from_dlpack(dlpack) dgl_tensor = nd.from_dlpack(dlpack)
new_arr = empty_shared_mem(name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)]) new_arr = empty_shared_mem(
name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)]
)
dgl_tensor.copyto(new_arr) dgl_tensor.copyto(new_arr)
dlpack = new_arr.to_dlpack() dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack) return F.zerocopy_from_dlpack(dlpack)
...@@ -10,6 +10,7 @@ from .init import zero_initializer ...@@ -10,6 +10,7 @@ from .init import zero_initializer
from .storages import TensorStorage from .storages import TensorStorage
from .utils import gather_pinned_tensor_rows, pin_memory_inplace from .utils import gather_pinned_tensor_rows, pin_memory_inplace
class _LazyIndex(object): class _LazyIndex(object):
def __init__(self, index): def __init__(self, index):
if isinstance(index, list): if isinstance(index, list):
...@@ -21,17 +22,17 @@ class _LazyIndex(object): ...@@ -21,17 +22,17 @@ class _LazyIndex(object):
return len(self._indices[-1]) return len(self._indices[-1])
def slice(self, index): def slice(self, index):
""" Create a new _LazyIndex object sliced by the given index tensor. """Create a new _LazyIndex object sliced by the given index tensor."""
"""
# if our indices are in the same context, lets just slice now and free # if our indices are in the same context, lets just slice now and free
# memory, otherwise do nothing until we have to # memory, otherwise do nothing until we have to
if F.context(self._indices[-1]) == F.context(index): if F.context(self._indices[-1]) == F.context(index):
return _LazyIndex(self._indices[:-1] + [F.gather_row(self._indices[-1], index)]) return _LazyIndex(
self._indices[:-1] + [F.gather_row(self._indices[-1], index)]
)
return _LazyIndex(self._indices + [index]) return _LazyIndex(self._indices + [index])
def flatten(self): def flatten(self):
""" Evaluate the chain of indices, and return a single index tensor. """Evaluate the chain of indices, and return a single index tensor."""
"""
flat_index = self._indices[0] flat_index = self._indices[0]
# here we actually need to resolve it # here we actually need to resolve it
for index in self._indices[1:]: for index in self._indices[1:]:
...@@ -40,6 +41,7 @@ class _LazyIndex(object): ...@@ -40,6 +41,7 @@ class _LazyIndex(object):
flat_index = F.gather_row(flat_index, index) flat_index = F.gather_row(flat_index, index)
return flat_index return flat_index
class LazyFeature(object): class LazyFeature(object):
"""Placeholder for feature prefetching. """Placeholder for feature prefetching.
...@@ -81,12 +83,16 @@ class LazyFeature(object): ...@@ -81,12 +83,16 @@ class LazyFeature(object):
id_ : Tensor, optional id_ : Tensor, optional
The ID tensor. The ID tensor.
""" """
__slots__ = ['name', 'id_']
__slots__ = ["name", "id_"]
def __init__(self, name=None, id_=None): def __init__(self, name=None, id_=None):
self.name = name self.name = name
self.id_ = id_ self.id_ = id_
def to(self, *args, **kwargs): # pylint: disable=invalid-name, unused-argument def to(
self, *args, **kwargs
): # pylint: disable=invalid-name, unused-argument
"""No-op. For compatibility of :meth:`Frame.to` method.""" """No-op. For compatibility of :meth:`Frame.to` method."""
return self return self
...@@ -104,7 +110,8 @@ class LazyFeature(object): ...@@ -104,7 +110,8 @@ class LazyFeature(object):
def record_stream(self, stream): def record_stream(self, stream):
"""No-op. For compatibility of :meth:`Frame.record_stream` method.""" """No-op. For compatibility of :meth:`Frame.record_stream` method."""
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
class Scheme(namedtuple("Scheme", ["shape", "dtype"])):
"""The column scheme. """The column scheme.
Parameters Parameters
...@@ -114,6 +121,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): ...@@ -114,6 +121,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype : backend-specific type object dtype : backend-specific type object
The feature data type. The feature data type.
""" """
# Pickling torch dtypes could be problemetic; this is a workaround. # Pickling torch dtypes could be problemetic; this is a workaround.
# I also have to create data_type_dict and reverse_data_type_dict # I also have to create data_type_dict and reverse_data_type_dict
# attribute just for this bug. # attribute just for this bug.
...@@ -128,6 +136,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): ...@@ -128,6 +136,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype = F.data_type_dict[dtype_str] dtype = F.data_type_dict[dtype_str]
return cls(shape, dtype) return cls(shape, dtype)
def infer_scheme(tensor): def infer_scheme(tensor):
"""Infer column scheme from the given tensor data. """Infer column scheme from the given tensor data.
...@@ -143,6 +152,7 @@ def infer_scheme(tensor): ...@@ -143,6 +152,7 @@ def infer_scheme(tensor):
""" """
return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor)) return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
class Column(TensorStorage): class Column(TensorStorage):
"""A column is a compact store of features of multiple nodes/edges. """A column is a compact store of features of multiple nodes/edges.
...@@ -185,6 +195,7 @@ class Column(TensorStorage): ...@@ -185,6 +195,7 @@ class Column(TensorStorage):
index : Tensor index : Tensor
Index tensor Index tensor
""" """
def __init__(self, storage, *args, **kwargs): def __init__(self, storage, *args, **kwargs):
super().__init__(storage) super().__init__(storage)
self._init(*args, **kwargs) self._init(*args, **kwargs)
...@@ -212,8 +223,14 @@ class Column(TensorStorage): ...@@ -212,8 +223,14 @@ class Column(TensorStorage):
index_ctx = F.context(self.index) index_ctx = F.context(self.index)
# If under the special case where the storage is pinned and the index is on # If under the special case where the storage is pinned and the index is on
# CUDA, directly call UVA slicing (even if they aree not in the same context). # CUDA, directly call UVA slicing (even if they aree not in the same context).
if storage_ctx != index_ctx and storage_ctx == F.cpu() and F.is_pinned(self.storage): if (
self.storage = gather_pinned_tensor_rows(self.storage, self.index) storage_ctx != index_ctx
and storage_ctx == F.cpu()
and F.is_pinned(self.storage)
):
self.storage = gather_pinned_tensor_rows(
self.storage, self.index
)
else: else:
# If index and storage is not in the same context, # If index and storage is not in the same context,
# copy index to the same context of storage. # copy index to the same context of storage.
...@@ -228,7 +245,9 @@ class Column(TensorStorage): ...@@ -228,7 +245,9 @@ class Column(TensorStorage):
# move data to the right device # move data to the right device
if self.device is not None: if self.device is not None:
self.storage = F.copy_to(self.storage, self.device[0], **self.device[1]) self.storage = F.copy_to(
self.storage, self.device[0], **self.device[1]
)
self.device = None self.device = None
# convert data to the right type # convert data to the right type
...@@ -247,8 +266,8 @@ class Column(TensorStorage): ...@@ -247,8 +266,8 @@ class Column(TensorStorage):
self._data_nd = None # should unpin data if it was pinned. self._data_nd = None # should unpin data if it was pinned.
self.pinned_by_dgl = False self.pinned_by_dgl = False
def to(self, device, **kwargs): # pylint: disable=invalid-name def to(self, device, **kwargs): # pylint: disable=invalid-name
""" Return a new column with columns copy to the targeted device (cpu/gpu). """Return a new column with columns copy to the targeted device (cpu/gpu).
Parameters Parameters
---------- ----------
...@@ -268,13 +287,13 @@ class Column(TensorStorage): ...@@ -268,13 +287,13 @@ class Column(TensorStorage):
@property @property
def dtype(self): def dtype(self):
""" Return the effective data type of this Column """ """Return the effective data type of this Column"""
if self.deferred_dtype is not None: if self.deferred_dtype is not None:
return self.deferred_dtype return self.deferred_dtype
return self.storage.dtype return self.storage.dtype
def astype(self, new_dtype): def astype(self, new_dtype):
""" Return a new column such that when its data is requested, """Return a new column such that when its data is requested,
it will be converted to new_dtype. it will be converted to new_dtype.
Parameters Parameters
...@@ -353,8 +372,10 @@ class Column(TensorStorage): ...@@ -353,8 +372,10 @@ class Column(TensorStorage):
""" """
feat_scheme = infer_scheme(feats) feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme: if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s." raise DGLError(
% (feat_scheme, self.scheme)) "Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)
)
self.data = F.scatter_row(self.data, rowids, feats) self.data = F.scatter_row(self.data, rowids, feats)
def extend(self, feats, feat_scheme=None): def extend(self, feats, feat_scheme=None):
...@@ -373,14 +394,22 @@ class Column(TensorStorage): ...@@ -373,14 +394,22 @@ class Column(TensorStorage):
feat_scheme = infer_scheme(feats) feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme: if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s." raise DGLError(
% (feat_scheme, self.scheme)) "Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)
)
self.data = F.cat([self.data, feats], dim=0) self.data = F.cat([self.data, feats], dim=0)
def clone(self): def clone(self):
"""Return a shallow copy of this column.""" """Return a shallow copy of this column."""
return Column(self.storage, self.scheme, self.index, self.device, self.deferred_dtype) return Column(
self.storage,
self.scheme,
self.index,
self.device,
self.deferred_dtype,
)
def deepclone(self): def deepclone(self):
"""Return a deepcopy of this column. """Return a deepcopy of this column.
...@@ -409,13 +438,25 @@ class Column(TensorStorage): ...@@ -409,13 +438,25 @@ class Column(TensorStorage):
Sub-column Sub-column
""" """
if self.index is None: if self.index is None:
return Column(self.storage, self.scheme, rowids, self.device, self.deferred_dtype) return Column(
self.storage,
self.scheme,
rowids,
self.device,
self.deferred_dtype,
)
else: else:
index = self.index index = self.index
if not isinstance(index, _LazyIndex): if not isinstance(index, _LazyIndex):
index = _LazyIndex(self.index) index = _LazyIndex(self.index)
index = index.slice(rowids) index = index.slice(rowids)
return Column(self.storage, self.scheme, index, self.device, self.deferred_dtype) return Column(
self.storage,
self.scheme,
index,
self.device,
self.deferred_dtype,
)
@staticmethod @staticmethod
def create(data): def create(data):
...@@ -435,30 +476,32 @@ class Column(TensorStorage): ...@@ -435,30 +476,32 @@ class Column(TensorStorage):
state = self.__dict__.copy() state = self.__dict__.copy()
# data pinning does not get serialized, so we need to remove that from # data pinning does not get serialized, so we need to remove that from
# the state # the state
state['_data_nd'] = None state["_data_nd"] = None
state['pinned_by_dgl'] = False state["pinned_by_dgl"] = False
return state return state
def __setstate__(self, state): def __setstate__(self, state):
index = None index = None
device = None device = None
if 'storage' in state and state['storage'] is not None: if "storage" in state and state["storage"] is not None:
assert 'index' not in state or state['index'] is None assert "index" not in state or state["index"] is None
assert 'device' not in state or state['device'] is None assert "device" not in state or state["device"] is None
else: else:
# we may have a column with only index information, and that is # we may have a column with only index information, and that is
# valid # valid
index = None if 'index' not in state else state['index'] index = None if "index" not in state else state["index"]
device = None if 'device' not in state else state['device'] device = None if "device" not in state else state["device"]
assert 'deferred_dtype' not in state or state['deferred_dtype'] is None assert "deferred_dtype" not in state or state["deferred_dtype"] is None
assert 'pinned_by_dgl' not in state or state['pinned_by_dgl'] is False assert "pinned_by_dgl" not in state or state["pinned_by_dgl"] is False
assert '_data_nd' not in state or state['_data_nd'] is None assert "_data_nd" not in state or state["_data_nd"] is None
self.__dict__ = state self.__dict__ = state
# properly initialize this object # properly initialize this object
self._init(self.scheme if hasattr(self, 'scheme') else None, self._init(
index=index, self.scheme if hasattr(self, "scheme") else None,
device=device) index=index,
device=device,
)
def _init(self, scheme=None, index=None, device=None, deferred_dtype=None): def _init(self, scheme=None, index=None, device=None, deferred_dtype=None):
self.scheme = scheme if scheme else infer_scheme(self.storage) self.scheme = scheme if scheme else infer_scheme(self.storage)
...@@ -472,7 +515,7 @@ class Column(TensorStorage): ...@@ -472,7 +515,7 @@ class Column(TensorStorage):
return self.clone() return self.clone()
def fetch(self, indices, device, pin_memory=False, **kwargs): def fetch(self, indices, device, pin_memory=False, **kwargs):
_ = self.data # materialize in case of lazy slicing & data transfer _ = self.data # materialize in case of lazy slicing & data transfer
return super().fetch(indices, device, pin_memory=pin_memory, **kwargs) return super().fetch(indices, device, pin_memory=pin_memory, **kwargs)
def pin_memory_(self): def pin_memory_(self):
...@@ -503,10 +546,11 @@ class Column(TensorStorage): ...@@ -503,10 +546,11 @@ class Column(TensorStorage):
---------- ----------
stream : torch.cuda.Stream. stream : torch.cuda.Stream.
""" """
if F.get_preferred_backend() != 'pytorch': if F.get_preferred_backend() != "pytorch":
raise DGLError("record_stream only supports the PyTorch backend.") raise DGLError("record_stream only supports the PyTorch backend.")
self.data.record_stream(stream) self.data.record_stream(stream)
class Frame(MutableMapping): class Frame(MutableMapping):
"""The columnar storage for node/edge features. """The columnar storage for node/edge features.
...@@ -523,6 +567,7 @@ class Frame(MutableMapping): ...@@ -523,6 +567,7 @@ class Frame(MutableMapping):
The number of rows in this frame. If ``data`` is provided and is not empty, The number of rows in this frame. If ``data`` is provided and is not empty,
``num_rows`` will be ignored and inferred from the given data. ``num_rows`` will be ignored and inferred from the given data.
""" """
def __init__(self, data=None, num_rows=None): def __init__(self, data=None, num_rows=None):
if data is None: if data is None:
self._columns = dict() self._columns = dict()
...@@ -531,8 +576,10 @@ class Frame(MutableMapping): ...@@ -531,8 +576,10 @@ class Frame(MutableMapping):
assert not isinstance(data, Frame) # sanity check for code refactor assert not isinstance(data, Frame) # sanity check for code refactor
# Note that we always create a new column for the given data. # Note that we always create a new column for the given data.
# This avoids two frames accidentally sharing the same column. # This avoids two frames accidentally sharing the same column.
self._columns = {k : v if isinstance(v, LazyFeature) else Column.create(v) self._columns = {
for k, v in data.items()} k: v if isinstance(v, LazyFeature) else Column.create(v)
for k, v in data.items()
}
self._num_rows = num_rows self._num_rows = num_rows
# infer num_rows & sanity check # infer num_rows & sanity check
for name, col in self._columns.items(): for name, col in self._columns.items():
...@@ -541,8 +588,10 @@ class Frame(MutableMapping): ...@@ -541,8 +588,10 @@ class Frame(MutableMapping):
if self._num_rows is None: if self._num_rows is None:
self._num_rows = len(col) self._num_rows = len(col)
elif len(col) != self._num_rows: elif len(col) != self._num_rows:
raise DGLError('Expected all columns to have same # rows (%d), ' raise DGLError(
'got %d on %r.' % (self._num_rows, len(col), name)) "Expected all columns to have same # rows (%d), "
"got %d on %r." % (self._num_rows, len(col), name)
)
# Initializer for empty values. Initializer is a callable. # Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised # If is none, then a warning will be raised
...@@ -590,7 +639,7 @@ class Frame(MutableMapping): ...@@ -590,7 +639,7 @@ class Frame(MutableMapping):
@property @property
def schemes(self): def schemes(self):
"""Return a dictionary of column name to column schemes.""" """Return a dictionary of column name to column schemes."""
return {k : col.scheme for k, col in self._columns.items()} return {k: col.scheme for k, col in self._columns.items()}
@property @property
def num_columns(self): def num_columns(self):
...@@ -658,14 +707,21 @@ class Frame(MutableMapping): ...@@ -658,14 +707,21 @@ class Frame(MutableMapping):
The column context. The column context.
""" """
if name in self: if name in self:
dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name) dgl_warning(
'Column "%s" already exists. Ignore adding this column again.'
% name
)
return return
if self.get_initializer(name) is None: if self.get_initializer(name) is None:
self._set_zero_default_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(name) initializer = self.get_initializer(name)
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, init_data = initializer(
ctx, slice(0, self.num_rows)) (self.num_rows,) + scheme.shape,
scheme.dtype,
ctx,
slice(0, self.num_rows),
)
self._columns[name] = Column(init_data, scheme) self._columns[name] = Column(init_data, scheme)
def add_rows(self, num_rows): def add_rows(self, num_rows):
...@@ -686,8 +742,12 @@ class Frame(MutableMapping): ...@@ -686,8 +742,12 @@ class Frame(MutableMapping):
if self.get_initializer(key) is None: if self.get_initializer(key) is None:
self._set_zero_default_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(key) initializer = self.get_initializer(key)
new_data = initializer((num_rows,) + scheme.shape, scheme.dtype, new_data = initializer(
ctx, slice(self._num_rows, self._num_rows + num_rows)) (num_rows,) + scheme.shape,
scheme.dtype,
ctx,
slice(self._num_rows, self._num_rows + num_rows),
)
feat_placeholders[key] = new_data feat_placeholders[key] = new_data
self._append(Frame(feat_placeholders)) self._append(Frame(feat_placeholders))
self._num_rows += num_rows self._num_rows += num_rows
...@@ -708,8 +768,10 @@ class Frame(MutableMapping): ...@@ -708,8 +768,10 @@ class Frame(MutableMapping):
col = Column.create(data) col = Column.create(data)
if len(col) != self.num_rows: if len(col) != self.num_rows:
raise DGLError('Expected data to have %d rows, got %d.' % raise DGLError(
(self.num_rows, len(col))) "Expected data to have %d rows, got %d."
% (self.num_rows, len(col))
)
self._columns[name] = col self._columns[name] = col
def update_row(self, rowids, data): def update_row(self, rowids, data):
...@@ -747,9 +809,12 @@ class Frame(MutableMapping): ...@@ -747,9 +809,12 @@ class Frame(MutableMapping):
if self.get_initializer(key) is None: if self.get_initializer(key) is None:
self._set_zero_default_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(key) initializer = self.get_initializer(key)
new_data = initializer((other.num_rows,) + scheme.shape, new_data = initializer(
scheme.dtype, ctx, (other.num_rows,) + scheme.shape,
slice(self._num_rows, self._num_rows + other.num_rows)) scheme.dtype,
ctx,
slice(self._num_rows, self._num_rows + other.num_rows),
)
other[key] = new_data other[key] = new_data
# append other to self # append other to self
for key, col in other._columns.items(): for key, col in other._columns.items():
...@@ -829,8 +894,10 @@ class Frame(MutableMapping): ...@@ -829,8 +894,10 @@ class Frame(MutableMapping):
Frame Frame
A deep-cloned frame. A deep-cloned frame.
""" """
newframe = Frame({k : col.deepclone() for k, col in self._columns.items()}, newframe = Frame(
self._num_rows) {k: col.deepclone() for k, col in self._columns.items()},
self._num_rows,
)
newframe._initializers = self._initializers newframe._initializers = self._initializers
newframe._default_initializer = self._default_initializer newframe._default_initializer = self._default_initializer
return newframe return newframe
...@@ -851,14 +918,14 @@ class Frame(MutableMapping): ...@@ -851,14 +918,14 @@ class Frame(MutableMapping):
Frame Frame
A new subframe. A new subframe.
""" """
subcols = {k : col.subcolumn(rowids) for k, col in self._columns.items()} subcols = {k: col.subcolumn(rowids) for k, col in self._columns.items()}
subf = Frame(subcols, len(rowids)) subf = Frame(subcols, len(rowids))
subf._initializers = self._initializers subf._initializers = self._initializers
subf._default_initializer = self._default_initializer subf._default_initializer = self._default_initializer
return subf return subf
def to(self, device, **kwargs): # pylint: disable=invalid-name def to(self, device, **kwargs): # pylint: disable=invalid-name
""" Return a new frame with columns copy to the targeted device (cpu/gpu). """Return a new frame with columns copy to the targeted device (cpu/gpu).
Parameters Parameters
---------- ----------
...@@ -873,7 +940,10 @@ class Frame(MutableMapping): ...@@ -873,7 +940,10 @@ class Frame(MutableMapping):
A new frame A new frame
""" """
newframe = self.clone() newframe = self.clone()
new_columns = {key : col.to(device, **kwargs) for key, col in newframe._columns.items()} new_columns = {
key: col.to(device, **kwargs)
for key, col in newframe._columns.items()
}
newframe._columns = new_columns newframe._columns = new_columns
return newframe return newframe
...@@ -899,8 +969,11 @@ class Frame(MutableMapping): ...@@ -899,8 +969,11 @@ class Frame(MutableMapping):
column.record_stream(stream) column.record_stream(stream)
def _astype_float(self, new_type): def _astype_float(self, new_type):
assert new_type in [F.float64, F.float32, F.float16], \ assert new_type in [
"'new_type' must be floating-point type: %s" % str(new_type) F.float64,
F.float32,
F.float16,
], "'new_type' must be floating-point type: %s" % str(new_type)
newframe = self.clone() newframe = self.clone()
new_columns = {} new_columns = {}
for name, column in self._columns.items(): for name, column in self._columns.items():
...@@ -913,16 +986,16 @@ class Frame(MutableMapping): ...@@ -913,16 +986,16 @@ class Frame(MutableMapping):
return newframe return newframe
def half(self): def half(self):
""" Return a new frame with all floating-point columns converted """Return a new frame with all floating-point columns converted
to half-precision (float16) """ to half-precision (float16)"""
return self._astype_float(F.float16) return self._astype_float(F.float16)
def float(self): def float(self):
""" Return a new frame with all floating-point columns converted """Return a new frame with all floating-point columns converted
to single-precision (float32) """ to single-precision (float32)"""
return self._astype_float(F.float32) return self._astype_float(F.float32)
def double(self): def double(self):
""" Return a new frame with all floating-point columns converted """Return a new frame with all floating-point columns converted
to double-precision (float64) """ to double-precision (float64)"""
return self._astype_float(F.float64) return self._astype_float(F.float64)
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from __future__ import absolute_import from __future__ import absolute_import
from .base import *
from .message import * from .message import *
from .reducer import * from .reducer import *
from .base import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment