Unverified Commit 0b3a447b authored by Hongzhi (Steve), Chen, committed by GitHub

auto format distributed (#5317)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 74c9d27d
"""Define distributed graph.""" """Define distributed graph."""
from collections.abc import MutableMapping import gc
from collections import namedtuple
import os import os
import gc from collections import namedtuple
from collections.abc import MutableMapping
import numpy as np import numpy as np
from ..heterograph import DGLGraph from .. import backend as F, heterograph_index
from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph
from ..transforms import compact_graphs
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID, ETYPE, ALL, is_all, DGLError
from .kvstore import KVServer, get_kvstore
from .._ffi.ndarray import empty_shared_mem from .._ffi.ndarray import empty_shared_mem
from ..ndarray import exist_shared_mem_array from ..base import ALL, DGLError, EID, ETYPE, is_all, NID
from ..convert import graph as dgl_graph, heterograph as dgl_heterograph
from ..frame import infer_scheme from ..frame import infer_scheme
from .partition import load_partition, load_partition_feats, load_partition_book
from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book from ..heterograph import DGLGraph
from .graph_partition_book import HeteroDataName, parse_hetero_data_name from ..ndarray import exist_shared_mem_array
from .graph_partition_book import NodePartitionPolicy, EdgePartitionPolicy from ..transforms import compact_graphs
from .graph_partition_book import _etype_str_to_tuple from . import graph_services, role, rpc
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from . import role
from .server_state import ServerState
from .rpc_server import start_server
from . import graph_services
from .graph_services import find_edges as dist_find_edges
from .graph_services import out_degrees as dist_out_degrees
from .graph_services import in_degrees as dist_in_degrees
from .dist_tensor import DistTensor from .dist_tensor import DistTensor
from .partition import RESERVED_FIELD_DTYPE from .graph_partition_book import (
_etype_str_to_tuple,
EdgePartitionPolicy,
get_shared_mem_partition_book,
HeteroDataName,
NodePartitionPolicy,
parse_hetero_data_name,
PartitionPolicy,
)
from .graph_services import (
find_edges as dist_find_edges,
in_degrees as dist_in_degrees,
out_degrees as dist_out_degrees,
)
from .kvstore import get_kvstore, KVServer
from .partition import (
load_partition,
load_partition_book,
load_partition_feats,
RESERVED_FIELD_DTYPE,
)
from .rpc_server import start_server
from .server_state import ServerState
from .shared_mem_utils import (
_get_edata_path,
_get_ndata_path,
_to_shared_mem,
DTYPE_DICT,
)

INIT_GRAPH = 800001


class InitGraphRequest(rpc.Request):
    """Init graph on the backup servers.

    When the backup servers start, they don't load the graph structure.
    This request tells the backup servers that they can map to the graph structure
    with shared memory.
    """

    def __init__(self, graph_name):
        self._graph_name = graph_name

@@ -58,9 +74,10 @@ class InitGraphRequest(rpc.Request):
        server_state.graph = _get_graph_from_shared_mem(self._graph_name)
        return InitGraphResponse(self._graph_name)


class InitGraphResponse(rpc.Response):
    """Ack the init graph request"""

    def __init__(self, graph_name):
        self._graph_name = graph_name

@@ -70,17 +87,24 @@ class InitGraphResponse(rpc.Response):
    def __setstate__(self, state):
        self._graph_name = state

def _copy_graph_to_shared_mem(g, graph_name, graph_format):
    new_g = g.shared_memory(graph_name, formats=graph_format)
    # We should share the node/edge data with the client explicitly instead of putting them
    # in the KVStore because some of the node/edge data may be duplicated.
    new_g.ndata["inner_node"] = _to_shared_mem(
        g.ndata["inner_node"], _get_ndata_path(graph_name, "inner_node")
    )
    new_g.ndata[NID] = _to_shared_mem(
        g.ndata[NID], _get_ndata_path(graph_name, NID)
    )
    new_g.edata["inner_edge"] = _to_shared_mem(
        g.edata["inner_edge"], _get_edata_path(graph_name, "inner_edge")
    )
    new_g.edata[EID] = _to_shared_mem(
        g.edata[EID], _get_edata_path(graph_name, EID)
    )
    # For a heterogeneous graph, we need to put ETYPE into the KVStore;
    # for a homogeneous graph, ETYPE does not exist.
    if ETYPE in g.edata:
@@ -90,51 +114,61 @@ def _copy_graph_to_shared_mem(g, graph_name, graph_format):
    )
    return new_g

def _get_shared_mem_ndata(g, graph_name, name):
    """Get shared-memory node data from the DistGraph server.

    This is called by the DistGraph client to access the node data in the DistGraph server
    with shared memory.
    """
    shape = (g.number_of_nodes(),)
    dtype = RESERVED_FIELD_DTYPE[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(
        _get_ndata_path(graph_name, name), False, shape, dtype
    )
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)


def _get_shared_mem_edata(g, graph_name, name):
    """Get shared-memory edge data from the DistGraph server.

    This is called by the DistGraph client to access the edge data in the DistGraph server
    with shared memory.
    """
    shape = (g.number_of_edges(),)
    dtype = RESERVED_FIELD_DTYPE[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(
        _get_edata_path(graph_name, name), False, shape, dtype
    )
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)


def _exist_shared_mem_array(graph_name, name):
    return exist_shared_mem_array(_get_edata_path(graph_name, name))

def _get_graph_from_shared_mem(graph_name):
    """Get the graph from the DistGraph server.

    The DistGraph server puts the graph structure of the local partition in the shared memory.
    The client can access the graph structure and some metadata on nodes and edges directly
    through shared memory to reduce the overhead of data access.
    """
    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(
        graph_name
    )
    if g is None:
        return None
    g = DGLGraph(g, ntypes, etypes)

    g.ndata["inner_node"] = _get_shared_mem_ndata(g, graph_name, "inner_node")
    g.ndata[NID] = _get_shared_mem_ndata(g, graph_name, NID)

    g.edata["inner_edge"] = _get_shared_mem_edata(g, graph_name, "inner_edge")
    g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)

    # Only a heterogeneous graph has ETYPE.
@@ -142,12 +176,15 @@ def _get_graph_from_shared_mem(graph_name):
    g.edata[ETYPE] = _get_shared_mem_edata(g, graph_name, ETYPE)
    return g
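
# A hedged sketch of the shared-memory round trip implemented above. The
# helper below is illustrative only (it is not part of DGL), and the graph
# name "mygraph" is an assumption: a server process pins its partition in
# shared memory, and a client process on the same machine re-maps the same
# structure by name without copying.
def _example_shared_mem_round_trip(local_g):
    # Server side: copy the partition structure and reserved fields.
    _copy_graph_to_shared_mem(local_g, "mygraph", ("csc", "coo"))
    # Client side: re-map the shared structure by name.
    client_g = _get_graph_from_shared_mem("mygraph")
    assert client_g is not None  # None means nothing was shared under the name
    return client_g
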

NodeSpace = namedtuple("NodeSpace", ["data"])
EdgeSpace = namedtuple("EdgeSpace", ["data"])


class HeteroNodeView(object):
    """A NodeView class to act as G.nodes for a DistGraph."""

    __slots__ = ["_graph"]

    def __init__(self, graph):
        self._graph = graph

@@ -156,9 +193,11 @@ class HeteroNodeView(object):
        assert isinstance(key, str)
        return NodeSpace(data=NodeDataView(self._graph, key))


class HeteroEdgeView(object):
    """An EdgeView class to act as G.edges for a DistGraph."""

    __slots__ = ["_graph"]

    def __init__(self, graph):
        self._graph = graph

@@ -169,10 +208,11 @@ class HeteroEdgeView(object):
        ), f"Expect edge type in string or triplet of string, but got {key}."
        return EdgeSpace(data=EdgeDataView(self._graph, key))


class NodeDataView(MutableMapping):
    """The data view class when dist_graph.ndata[...].data is called."""

    __slots__ = ["_graph", "_data"]

    def __init__(self, g, ntype=None):
        self._graph = g

@@ -208,13 +248,16 @@ class NodeDataView(MutableMapping):
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = "DistTensor(shape={}, dtype={})".format(
                str(shape), str(dtype)
            )
        return repr(reprs)


class EdgeDataView(MutableMapping):
    """The data view class when G.edges[...].data is called."""

    __slots__ = ["_graph", "_data"]

    def __init__(self, g, etype=None):
        self._graph = g

@@ -249,12 +292,14 @@ class EdgeDataView(MutableMapping):
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = "DistTensor(shape={}, dtype={})".format(
                str(shape), str(dtype)
            )
        return repr(reprs)


class DistGraphServer(KVServer):
    """The DistGraph server.

    This DistGraph server loads the graph data and sets up a service so that trainers and
    samplers can read data of a graph partition (graph structure, node data and edge data)
@@ -289,15 +334,26 @@ class DistGraphServer(KVServer):
        Whether to keep the server alive when clients exit.
    net_type : str
        Backend rpc type: ``'socket'`` or ``'tensorpipe'``.
    """

    def __init__(
        self,
        server_id,
        ip_config,
        num_servers,
        num_clients,
        part_config,
        disable_shared_mem=False,
        graph_format=("csc", "coo"),
        keep_alive=False,
        net_type="socket",
    ):
        super(DistGraphServer, self).__init__(
            server_id=server_id,
            ip_config=ip_config,
            num_servers=num_servers,
            num_clients=num_clients,
        )
        self.ip_config = ip_config
        self.num_servers = num_servers
        self.keep_alive = keep_alive
@@ -305,13 +361,22 @@ class DistGraphServer(KVServer):
        # Load graph partition data.
        if self.is_backup_server():
            # The backup server doesn't load the graph partition. It'll be initialized afterwards.
            self.gpb, graph_name, ntypes, etypes = load_partition_book(
                part_config, self.part_id
            )
            self.client_g = None
        else:
            # Loading of node/edge_feats is deferred to lower the peak memory consumption.
            (
                self.client_g,
                _,
                _,
                self.gpb,
                graph_name,
                ntypes,
                etypes,
            ) = load_partition(part_config, self.part_id, load_feats=False)
            print("load " + graph_name)
            # formatting dtype
            # TODO(Rui): Formatting forcibly is not a perfect solution.
            #   We'd better store all dtypes when mapping to shared memory.
@@ -319,72 +384,97 @@ class DistGraphServer(KVServer):
            for k, dtype in RESERVED_FIELD_DTYPE.items():
                if k in self.client_g.ndata:
                    self.client_g.ndata[k] = F.astype(
                        self.client_g.ndata[k], dtype
                    )
                if k in self.client_g.edata:
                    self.client_g.edata[k] = F.astype(
                        self.client_g.edata[k], dtype
                    )
            # Create the graph formats specified by the users.
            self.client_g = self.client_g.formats(graph_format)
            self.client_g.create_formats_()
            if not disable_shared_mem:
                self.client_g = _copy_graph_to_shared_mem(
                    self.client_g, graph_name, graph_format
                )

        if not disable_shared_mem:
            self.gpb.shared_memory(graph_name)
        assert self.gpb.partid == self.part_id
        for ntype in ntypes:
            node_name = HeteroDataName(True, ntype, "")
            self.add_part_policy(
                PartitionPolicy(node_name.policy_str, self.gpb)
            )
        for etype in etypes:
            edge_name = HeteroDataName(False, etype, "")
            self.add_part_policy(
                PartitionPolicy(edge_name.policy_str, self.gpb)
            )

        if not self.is_backup_server():
            node_feats, _ = load_partition_feats(
                part_config, self.part_id, load_nodes=True, load_edges=False
            )
            for name in node_feats:
                # The feature name has the format node_type + "/" + feature_name to avoid
                # feature name collisions between different node types.
                ntype, feat_name = name.split("/")
                data_name = HeteroDataName(True, ntype, feat_name)
                self.init_data(
                    name=str(data_name),
                    policy_str=data_name.policy_str,
                    data_tensor=node_feats[name],
                )
                self.orig_data.add(str(data_name))
            # Free the memory once node features are copied to shared memory.
            del node_feats
            gc.collect()
            _, edge_feats = load_partition_feats(
                part_config, self.part_id, load_nodes=False, load_edges=True
            )
            for name in edge_feats:
                # The feature name has the format edge_type + "/" + feature_name to avoid
                # feature name collisions between different edge types.
                etype, feat_name = name.split("/")
                etype = _etype_str_to_tuple(etype)
                data_name = HeteroDataName(False, etype, feat_name)
                self.init_data(
                    name=str(data_name),
                    policy_str=data_name.policy_str,
                    data_tensor=edge_feats[name],
                )
                self.orig_data.add(str(data_name))
            # Free the memory once edge features are copied to shared memory.
            del edge_feats
            gc.collect()

    def start(self):
        """Start graph store server."""
        # start server
        server_state = ServerState(
            kv_store=self,
            local_g=self.client_g,
            partition_book=self.gpb,
            keep_alive=self.keep_alive,
        )
        print(
            "start graph service on server {} for part {}".format(
                self.server_id, self.part_id
            )
        )
        start_server(
            server_id=self.server_id,
            ip_config=self.ip_config,
            num_servers=self.num_servers,
            num_clients=self.num_clients,
            server_state=server_state,
            net_type=self.net_type,
        )
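
# A hedged launch sketch (the helper, file paths and counts below are
# illustrative, not taken from this diff): each machine typically runs one
# DistGraphServer process per server role, pointing at the shared partition
# config, and then blocks in start() to serve the graph structure and
# features to trainers and samplers.
def _example_launch_server():
    server = DistGraphServer(
        server_id=0,
        ip_config="ip_config.txt",
        num_servers=1,
        num_clients=4,
        part_config="data/mygraph.json",
    )
    server.start()  # blocks and serves requests until shutdown
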

class DistGraph:
    """The class for accessing a distributed graph.

    This class provides a subset of DGLGraph APIs for accessing partitioned graph data in
    distributed GNN training and inference. Thus, its main use case is to work with
@@ -455,35 +545,45 @@ class DistGraph:
    DGL's distributed training by default runs server processes and trainer processes on the same
    set of machines. If users need to run them on different sets of machines, it requires
    manually setting up servers and trainers. The setup is not fully tested yet.
    """

    def __init__(self, graph_name, gpb=None, part_config=None):
        self.graph_name = graph_name
        if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
            assert (
                part_config is not None
            ), "When running in the standalone mode, the partition config file is required"
            self._client = get_kvstore()
            assert (
                self._client is not None
            ), "Distributed module is not initialized. Please call dgl.distributed.initialize."
            # Load graph partition data.
            g, node_feats, edge_feats, self._gpb, _, _, _ = load_partition(
                part_config, 0
            )
            assert (
                self._gpb.num_partitions() == 1
            ), "The standalone mode can only work with graph data with one partition"
            if self._gpb is None:
                self._gpb = gpb
            self._g = g
            for name in node_feats:
                # The feature name has the format node_type + "/" + feature_name.
                ntype, feat_name = name.split("/")
                self._client.add_data(
                    str(HeteroDataName(True, ntype, feat_name)),
                    node_feats[name],
                    NodePartitionPolicy(self._gpb, ntype=ntype),
                )
            for name in edge_feats:
                # The feature name has the format edge_type + "/" + feature_name.
                etype, feat_name = name.split("/")
                etype = _etype_str_to_tuple(etype)
                self._client.add_data(
                    str(HeteroDataName(False, etype, feat_name)),
                    edge_feats[name],
                    EdgePartitionPolicy(self._gpb, etype=etype),
                )
            self._client.map_shared_data(self._gpb)
            rpc.set_num_client(1)
        else:
@@ -501,17 +601,20 @@ class DistGraph:
        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md["num_nodes"])
            self._num_edges += int(part_md["num_edges"])

        # When we store node/edge types in a list, they are stored in the order of type IDs.
        self._ntype_map = {ntype: i for i, ntype in enumerate(self.ntypes)}
        self._etype_map = {
            etype: i for i, etype in enumerate(self.canonical_etypes)
        }
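
    # Construction sketch (hedged; the graph name and config path are
    # illustrative): in standalone mode the partition config is mandatory,
    # while in distributed mode trainers call dgl.distributed.initialize()
    # first and only pass the graph name, e.g.
    #
    #   g = DistGraph("mygraph", part_config="data/mygraph.json")  # standalone
    #   g = DistGraph("mygraph")  # distributed, after initialize()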

    def _init(self, gpb):
        self._client = get_kvstore()
        assert (
            self._client is not None
        ), "Distributed module is not initialized. Please call dgl.distributed.initialize."
        self._g = _get_graph_from_shared_mem(self.graph_name)
        self._gpb = get_shared_mem_partition_book(self.graph_name)
        if self._gpb is None:
@@ -519,20 +622,24 @@ class DistGraph:
        self._client.map_shared_data(self._gpb)

    def _init_ndata_store(self):
        """Initialize node data store."""
        self._ndata_store = {}
        for ntype in self.ntypes:
            names = self._get_ndata_names(ntype)
            data = {}
            for name in names:
                assert name.is_node()
                policy = PartitionPolicy(
                    name.policy_str, self.get_partition_book()
                )
                dtype, shape, _ = self._client.get_data_meta(str(name))
                # We create a wrapper on the existing tensor in the kvstore.
                data[name.get_name()] = DistTensor(
                    shape,
                    dtype,
                    name.get_name(),
                    part_policy=policy,
                    attach=False,
                )
            if len(self.ntypes) == 1:
                self._ndata_store = data
@@ -540,20 +647,24 @@ class DistGraph:
                self._ndata_store[ntype] = data

    def _init_edata_store(self):
        """Initialize edge data store."""
        self._edata_store = {}
        for etype in self.canonical_etypes:
            names = self._get_edata_names(etype)
            data = {}
            for name in names:
                assert name.is_edge()
                policy = PartitionPolicy(
                    name.policy_str, self.get_partition_book()
                )
                dtype, shape, _ = self._client.get_data_meta(str(name))
                # We create a wrapper on the existing tensor in the kvstore.
                data[name.get_name()] = DistTensor(
                    shape,
                    dtype,
                    name.get_name(),
                    part_policy=policy,
                    attach=False,
                )
            if len(self.canonical_etypes) == 1:
                self._edata_store = data
@@ -572,12 +683,12 @@ class DistGraph:
        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md["num_nodes"])
            self._num_edges += int(part_md["num_edges"])

    @property
    def local_partition(self):
        """Return the local partition on the client.

        DistGraph provides a global view of the distributed graph. Internally,
        it may contain a partition of the graph if it is co-located with
@@ -588,19 +699,17 @@ class DistGraph:
        -------
        DGLGraph
            The local partition
        """
        return self._g

    @property
    def nodes(self):
        """Return a node view"""
        return HeteroNodeView(self)

    @property
    def edges(self):
        """Return an edge view"""
        return HeteroEdgeView(self)

    @property
@@ -612,7 +721,9 @@ class DistGraph:
        NodeDataView
            The data view in the distributed graph storage.
        """
        assert (
            len(self.ntypes) == 1
        ), "ndata only works for a graph with one node type."
        return NodeDataView(self)

    @property
@@ -624,7 +735,9 @@ class DistGraph:
        EdgeDataView
            The data view in the distributed graph storage.
        """
        assert (
            len(self.etypes) == 1
        ), "edata only works for a graph with one edge type."
        return EdgeDataView(self)
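
    # Access sketch (hedged; "feat" is an illustrative field name): entries of
    # ndata/edata are DistTensor objects backed by the kvstore rather than
    # local tensors, so indexing is what moves data, e.g.
    #
    #   feat = g.ndata["feat"]        # a DistTensor handle, no data moved yet
    #   batch_feat = feat[node_ids]   # pulls only the requested rows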

    @property
@@ -811,8 +924,10 @@ class DistGraph:
        """
        if ntype is None:
            if len(self._ntype_map) != 1:
                raise DGLError(
                    "Node type name must be specified if there are more than "
                    "one node type."
                )
            return 0
        return self._ntype_map[ntype]

@@ -833,8 +948,10 @@ class DistGraph:
        """
        if etype is None:
            if len(self._etype_map) != 1:
                raise DGLError(
                    "Edge type name must be specified if there are more than "
                    "one edge type."
                )
            return 0
        etype = self.to_canonical_etype(etype)
        return self._etype_map[etype]
@@ -871,7 +988,9 @@ class DistGraph:
            if len(self.ntypes) == 1:
                return self._gpb._num_nodes(self.ntypes[0])
            else:
                return sum(
                    [self._gpb._num_nodes(ntype) for ntype in self.ntypes]
                )
        return self._gpb._num_nodes(ntype)

    def num_edges(self, etype=None):
@@ -901,8 +1020,12 @@ class DistGraph:
        123718280
        """
        if etype is None:
            return sum(
                [
                    self._gpb._num_edges(c_etype)
                    for c_etype in self.canonical_etypes
                ]
            )
        return self._gpb._num_edges(etype)

    def out_degrees(self, u=ALL):
@@ -1058,7 +1181,7 @@ class DistGraph:
        return schemes

    def rank(self):
        """The rank of the current DistGraph.

        This returns a unique number to identify the DistGraph object among all of
        the client processes.
@@ -1067,11 +1190,11 @@ class DistGraph:
        -------
        int
            The rank of the current DistGraph.
        """
        return role.get_global_rank()

    def find_edges(self, edges, etype=None):
        """Given an edge ID array, return the source and destination node ID
        arrays ``s`` and ``d``. ``s[i]`` and ``d[i]`` are the source and
        destination node IDs for edge ``eid[i]``.

@@ -1098,7 +1221,9 @@ class DistGraph:
            The destination node ID array.
        """
        if etype is None:
            assert (
                len(self.etypes) == 1
            ), "find_edges requires etype for heterogeneous graphs."

        gpb = self.get_partition_book()
        if len(gpb.etypes) > 1:
@@ -1151,7 +1276,9 @@ class DistGraph:
            for etype, edge in edges.items():
                etype = self.to_canonical_etype(etype)
                subg[etype] = self.find_edges(edge, etype)
            num_nodes = {
                ntype: self.number_of_nodes(ntype) for ntype in self.ntypes
            }
            subg = dgl_heterograph(subg, num_nodes_dict=num_nodes)
            for etype in edges:
                subg.edges[etype].data[EID] = edges[etype]
@@ -1163,7 +1290,7 @@ class DistGraph:
        if relabel_nodes:
            subg = compact_graphs(subg)
        assert store_ids, "edge_subgraph always stores original node/edge IDs."
        return subg

    def get_partition_book(self):
@@ -1220,55 +1347,72 @@ class DistGraph:
        return EdgePartitionPolicy(self.get_partition_book(), etype)

    def barrier(self):
        """Barrier for all client nodes.

        This API blocks the current process until all the clients invoke this API.
        Please use this API with caution.
        """
        self._client.barrier()

    def sample_neighbors(
        self,
        seed_nodes,
        fanout,
        edge_dir="in",
        prob=None,
        exclude_edges=None,
        replace=False,
        etype_sorted=True,
        output_device=None,
    ):
        # pylint: disable=unused-argument
        """Sample neighbors from a distributed graph."""
        if len(self.etypes) > 1:
            frontier = graph_services.sample_etype_neighbors(
                self,
                seed_nodes,
                fanout,
                replace=replace,
                etype_sorted=etype_sorted,
                prob=prob,
            )
        else:
            frontier = graph_services.sample_neighbors(
                self, seed_nodes, fanout, replace=replace, prob=prob
            )
        return frontier
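
    # Sampling sketch (hedged; seed_nodes is an illustrative tensor of node
    # IDs): the returned frontier is a subgraph of sampled edges that can be
    # converted to a message-flow block for mini-batch training, e.g.
    #
    #   frontier = g.sample_neighbors(seed_nodes, fanout=10)
    #   block = dgl.to_block(frontier, seed_nodes)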

    def _get_ndata_names(self, ntype=None):
        """Get the names of all node data."""
        names = self._client.gdata_name_list()
        ndata_names = []
        for name in names:
            name = parse_hetero_data_name(name)
            right_type = (
                (name.get_type() == ntype) if ntype is not None else True
            )
            if name.is_node() and right_type:
                ndata_names.append(name)
        return ndata_names

    def _get_edata_names(self, etype=None):
        """Get the names of all edge data."""
        if etype is not None:
            etype = self.to_canonical_etype(etype)
        names = self._client.gdata_name_list()
        edata_names = []
        for name in names:
            name = parse_hetero_data_name(name)
            right_type = (
                (name.get_type() == etype) if etype is not None else True
            )
            if name.is_edge() and right_type:
                edata_names.append(name)
        return edata_names

def _get_overlap(mask_arr, ids):
    """Select the IDs given a boolean mask array.

    The boolean mask array indicates all of the IDs to be selected. We want to
    find the overlap between the IDs selected by the boolean mask array and
@@ -1293,15 +1437,18 @@ def _get_overlap(mask_arr, ids):
    masks = F.gather_row(F.tensor(mask_arr), ids)
    return F.boolean_mask(ids, masks)
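
# Worked example (with illustrative values): for mask_arr = [1, 0, 1, 1, 0]
# and ids = [0, 1, 3], gather_row picks the mask values [1, 0, 1] at those
# IDs, so boolean_mask keeps only [0, 3].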

def _split_local(partition_book, rank, elements, local_eles):
    """Split the input element list with respect to data locality."""
    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    if rank is None:
        rank = role.get_trainer_rank()
    assert (
        rank < num_clients
    ), "The input rank ({}) is incorrect. #Trainers: {}".format(
        rank, num_clients
    )
    # All ranks of the clients in the same machine are in a contiguous range.
    client_id_in_part = rank % num_client_per_part
    local_eles = _get_overlap(elements, local_eles)
@@ -1310,22 +1457,25 @@ def _split_local(partition_book, rank, elements, local_eles):
    size = len(local_eles) // num_client_per_part
    # if this isn't the last client in the partition.
    if client_id_in_part + 1 < num_client_per_part:
        return local_eles[
            (size * client_id_in_part) : (size * (client_id_in_part + 1))
        ]
    else:
        return local_eles[(size * client_id_in_part) :]

def _even_offset(n, k):
    """Split an array of length n into k segments such that the difference of
    their lengths is at most 1. Return the offset of each segment.
    """
    eles_per_part = n // k
    offset = np.array([0] + [eles_per_part] * k, dtype=int)
    offset[1 : n - eles_per_part * k + 1] += 1
    return np.cumsum(offset)
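
# Worked example: _even_offset(10, 3) starts from [0, 3, 3, 3], adds 1 to the
# first 10 - 3 * 3 = 1 segment, and returns the cumulative sum
# array([0, 4, 7, 10]), i.e. segments of lengths 4, 3 and 3.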

def _split_even_to_part(partition_book, elements):
    """Split the input element list evenly."""
    # Here we divide the element list as evenly as possible. If we use range partitioning,
    # the split results also respect the data locality. Range partitioning is the default
    # strategy.
@@ -1350,7 +1500,9 @@ def _split_even_to_part(partition_book, elements):
    part_eles = None
    # Compute the nonzero tensor of each block instead of the whole tensor to save memory.
    for idx in range(0, num_elements, block_size):
        nonzero_block = F.nonzero_1d(
            elements[idx : min(idx + block_size, num_elements)]
        )
        x = y
        y += len(nonzero_block)
        if y > left and x < right:
@@ -1366,6 +1518,7 @@ def _split_even_to_part(partition_book, elements):
    return part_eles

def _split_random_within_part(partition_book, rank, part_eles):
    # If there is more than one client in a partition, we need to randomly select a subset of
    # elements in the partition for a client. We have to make sure that the set of elements
@@ -1377,8 +1530,11 @@ def _split_random_within_part(partition_book, rank, part_eles):
        return part_eles
    if rank is None:
        rank = role.get_trainer_rank()
    assert (
        rank < num_clients
    ), "The input rank ({}) is incorrect. #Trainers: {}".format(
        rank, num_clients
    )
    client_id_in_part = rank % num_client_per_part
    offset = _even_offset(len(part_eles), num_client_per_part)
@@ -1387,12 +1543,20 @@ def _split_random_within_part(partition_book, rank, part_eles):
    # of elements.
    np.random.seed(partition_book.partid)
    rand_idx = np.random.permutation(len(part_eles))
    rand_idx = rand_idx[
        offset[client_id_in_part] : offset[client_id_in_part + 1]
    ]
    idx, _ = F.sort_1d(F.tensor(rand_idx))
    return F.gather_row(part_eles, idx)
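
# Determinism sketch: every client in the same partition seeds numpy with the
# same partition_book.partid, so they all generate the identical permutation
# and then take disjoint offset ranges from it. For example, with 10 elements
# and 3 clients per partition, the offsets are [0, 4, 7, 10], so client 0
# takes 4 elements and clients 1 and 2 take 3 each, with no overlap.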

def _split_by_trainer_id(
    partition_book,
    part_eles,
    trainer_id,
    num_client_per_part,
    client_id_in_part,
):
    # TODO(zhengda): MXNet cannot deal with empty tensors, which makes the implementation
    # much more difficult. Let's just use numpy for the computation for now. We just
    # perform operations on vectors. It shouldn't be too difficult.
@@ -1400,16 +1564,23 @@ def _split_by_trainer_id(partition_book, part_eles, trainer_id,
    part_eles = F.asnumpy(part_eles)
    part_id = trainer_id // num_client_per_part
    trainer_id = trainer_id % num_client_per_part
    local_eles = part_eles[
        np.nonzero(part_id[part_eles] == partition_book.partid)[0]
    ]
    # These are the IDs of the local elements in the partition. The IDs are global IDs.
    remote_eles = part_eles[
        np.nonzero(part_id[part_eles] != partition_book.partid)[0]
    ]
    # These are the IDs of the remote nodes in the partition. The IDs are global IDs.
    local_eles_idx = np.concatenate(
        [
            np.nonzero(trainer_id[local_eles] == i)[0]
            for i in range(num_client_per_part)
        ],
        # trainer_id[local_eles] is the trainer IDs of local nodes in the partition and we
        # pick out the indices where the node belongs to each trainer i respectively, and
        # concatenate them.
        axis=0,
    )
    # `local_eles_idx` is used to sort `local_eles` according to `trainer_id`. It is a
    # permutation of 0...(len(local_eles)-1).
@@ -1421,15 +1592,28 @@ def _split_by_trainer_id(partition_book, part_eles, trainer_id,
    remote_offsets = _even_offset(len(remote_eles), num_client_per_part)

    client_local_eles = local_eles[
        local_offsets[client_id_in_part] : local_offsets[client_id_in_part + 1]
    ]
    client_remote_eles = remote_eles[
        remote_offsets[client_id_in_part] : remote_offsets[
            client_id_in_part + 1
        ]
    ]
    client_eles = np.concatenate(
        [client_local_eles, client_remote_eles], axis=0
    )
    return F.tensor(client_eles)
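
# Encoding sketch: the input trainer_id stores one global trainer rank per
# element, which the function decodes into a machine and a local trainer.
# For example, with num_client_per_part = 2, a stored value of 5 decodes to
# partition 5 // 2 = 2 and local trainer 5 % 2 = 1.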

def node_split(
    nodes,
    partition_book=None,
    ntype="_N",
    rank=None,
    force_even=True,
    node_trainer_ids=None,
):
    """Split nodes and return a subset for the local rank.

    This function splits the input nodes based on the partition book and
    returns a subset of nodes for the local rank. This method is used for
@@ -1469,28 +1653,32 @@ def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=Tru
    -------
    1D-tensor
        The vector of node IDs that belong to the rank.
    """
    if not isinstance(nodes, DistTensor):
        assert (
            partition_book is not None
        ), "Regular tensor requires a partition book."
    elif partition_book is None:
        partition_book = nodes.part_policy.partition_book

    assert len(nodes) == partition_book._num_nodes(
        ntype
    ), "The length of the boolean mask vector should be the number of nodes in the graph."
    if rank is None:
        rank = role.get_trainer_rank()
    if force_even:
        num_clients = role.get_num_trainers()
        num_client_per_part = num_clients // partition_book.num_partitions()
        assert (
            num_clients % partition_book.num_partitions() == 0
        ), "The total number of clients should be a multiple of the number of partitions."
        part_nid = _split_even_to_part(partition_book, nodes)
        if num_client_per_part == 1:
            return part_nid
        elif node_trainer_ids is None:
            return _split_random_within_part(partition_book, rank, part_nid)
        else:
            trainer_id = node_trainer_ids[0 : len(node_trainer_ids)]
            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1

            if max_trainer_id > num_clients:
@@ -1498,19 +1686,33 @@ def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=Tru
                # trainers is less than the `num_trainers_per_machine` previously assigned during
                # partitioning.
                assert max_trainer_id % num_clients == 0
                trainer_id //= max_trainer_id // num_clients

            client_id_in_part = rank % num_client_per_part
            return _split_by_trainer_id(
                partition_book,
                part_nid,
                trainer_id,
                num_client_per_part,
                client_id_in_part,
            )
    else:
        # Get all nodes that belong to the rank.
        local_nids = partition_book.partid2nids(
            partition_book.partid, ntype=ntype
        )
        return _split_local(partition_book, rank, nodes, local_nids)
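
# Usage sketch (hedged; "train_mask" is an illustrative field name): every
# trainer passes the same boolean mask over all nodes and receives a disjoint
# slice of the masked node IDs for its rank, e.g.
#
#   train_mask = g.ndata["train_mask"]
#   train_nids = node_split(train_mask, g.get_partition_book())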
def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=True,
edge_trainer_ids=None): def edge_split(
''' Split edges and return a subset for the local rank. edges,
partition_book=None,
etype="_E",
rank=None,
force_even=True,
edge_trainer_ids=None,
):
"""Split edges and return a subset for the local rank.
This function splits the input edges based on the partition book and This function splits the input edges based on the partition book and
returns a subset of edges for the local rank. This method is used for returns a subset of edges for the local rank. This method is used for
...@@ -1550,27 +1752,31 @@ def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=Tru ...@@ -1550,27 +1752,31 @@ def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=Tru
------- -------
1D-tensor 1D-tensor
The vector of edge IDs that belong to the rank. The vector of edge IDs that belong to the rank.
''' """
if not isinstance(edges, DistTensor): if not isinstance(edges, DistTensor):
assert partition_book is not None, 'Regular tensor requires a partition book.' assert (
partition_book is not None
), "Regular tensor requires a partition book."
elif partition_book is None: elif partition_book is None:
partition_book = edges.part_policy.partition_book partition_book = edges.part_policy.partition_book
assert len(edges) == partition_book._num_edges(etype), \ assert len(edges) == partition_book._num_edges(
'The length of boolean mask vector should be the number of edges in the graph.' etype
), "The length of boolean mask vector should be the number of edges in the graph."
if rank is None: if rank is None:
rank = role.get_trainer_rank() rank = role.get_trainer_rank()
if force_even: if force_even:
num_clients = role.get_num_trainers() num_clients = role.get_num_trainers()
num_client_per_part = num_clients // partition_book.num_partitions() num_client_per_part = num_clients // partition_book.num_partitions()
assert num_clients % partition_book.num_partitions() == 0, \ assert (
'The total number of clients should be multiple of the number of partitions.' num_clients % partition_book.num_partitions() == 0
), "The total number of clients should be multiple of the number of partitions."
part_eid = _split_even_to_part(partition_book, edges) part_eid = _split_even_to_part(partition_book, edges)
if num_client_per_part == 1: if num_client_per_part == 1:
return part_eid return part_eid
elif edge_trainer_ids is None: elif edge_trainer_ids is None:
return _split_random_within_part(partition_book, rank, part_eid) return _split_random_within_part(partition_book, rank, part_eid)
else: else:
trainer_id = edge_trainer_ids[0:len(edge_trainer_ids)] trainer_id = edge_trainer_ids[0 : len(edge_trainer_ids)]
max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1 max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1
if max_trainer_id > num_clients: if max_trainer_id > num_clients:
@@ -1578,14 +1784,22 @@ def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=Tru
                # trainers is less than the `num_trainers_per_machine` previously assigned during
                # partitioning.
                assert max_trainer_id % num_clients == 0
                trainer_id //= max_trainer_id // num_clients
            client_id_in_part = rank % num_client_per_part
            return _split_by_trainer_id(
                partition_book,
                part_eid,
                trainer_id,
                num_client_per_part,
                client_id_in_part,
            )
    else:
        # Get all edges that belong to the rank.
        local_eids = partition_book.partid2eids(
            partition_book.partid, etype=etype
        )
        return _split_local(partition_book, rank, edges, local_eids)


rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
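
A minimal usage sketch of ``edge_split`` (illustrative only: it assumes an
initialized DistDGL cluster and an edge boolean mask named "train_mask"; both
names are placeholders):

import dgl
import torch as th

dgl.distributed.initialize("ip_config.txt")
g = dgl.distributed.DistGraph("graph_name")
# Each trainer receives a disjoint, roughly even subset of the masked edges.
train_eids = dgl.distributed.edge_split(
    g.edata["train_mask"], g.get_partition_book(), force_even=True
)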
@@ -2,21 +2,24 @@
import os

from .. import backend as F, utils
from .dist_context import is_initialized
from .kvstore import get_kvstore
from .role import get_role
from .rpc import get_group_id


def _default_init_data(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


# These IDs can identify the anonymous distributed tensors.
DIST_TENSOR_ID = 0


class DistTensor:
    """Distributed tensor.

    ``DistTensor`` refers to a distributed tensor sharded and stored in a cluster of machines.
    It has the same interface as a PyTorch Tensor for accessing its metadata (e.g., shape and data type).
@@ -103,12 +106,23 @@ class DistTensor:
    The creation of ``DistTensor`` is a synchronized operation. When a trainer process tries to
    create a ``DistTensor`` object, the creation succeeds only when all trainer processes
    do the same.
    """
    def __init__(
        self,
        shape,
        dtype,
        name=None,
        init_func=None,
        part_policy=None,
        persistent=False,
        is_gdata=True,
        attach=True,
    ):
        self.kvstore = get_kvstore()
        assert (
            self.kvstore is not None
        ), "Distributed module is not initialized. Please call dgl.distributed.initialize."
        self._shape = shape
        self._dtype = dtype
        self._attach = attach
@@ -124,18 +138,21 @@ class DistTensor:
            # If multiple partition policies match the input shape, we cannot
            # decide which is the right one automatically. We should ask users
            # to provide one.
            assert part_policy is None, (
                "Multiple partition policies match the input shape. "
                + "Please provide a partition policy explicitly."
            )
            part_policy = policy
        assert part_policy is not None, (
            "Cannot find a right partition policy. It is either because "
            + "its first dimension does not match the number of nodes or edges "
            + "of a distributed graph or there does not exist a distributed graph."
        )
        self._part_policy = part_policy
        assert (
            part_policy.get_size() == shape[0]
        ), "The partition policy does not match the input shape."
        if init_func is None:
            init_func = _default_init_data
@@ -143,13 +160,17 @@ class DistTensor:
        # If a user doesn't provide a name, we generate a name ourselves.
        # We need to generate the name in a deterministic way.
        if name is None:
            assert (
                not persistent
            ), "We cannot generate anonymous persistent distributed tensors"
            global DIST_TENSOR_ID
            # All processes of the same role should create DistTensor synchronously.
            # Thus, all of them should have the same IDs.
            name = "anonymous-" + get_role() + "-" + str(DIST_TENSOR_ID)
            DIST_TENSOR_ID += 1
        assert isinstance(name, str), "name {} is type {}".format(
            name, type(name)
        )
        name = self._attach_group_id(name)
        self._tensor_name = name
        data_name = part_policy.get_data_name(name)
@@ -157,16 +178,24 @@ class DistTensor:
        self._persistent = persistent
        if self._name not in exist_names:
            self._owner = True
            self.kvstore.init_data(
                self._name, shape, dtype, part_policy, init_func, is_gdata
            )
        else:
            self._owner = False
            dtype1, shape1, _ = self.kvstore.get_data_meta(self._name)
            assert (
                dtype == dtype1
            ), "The dtype does not match with the existing tensor"
            assert (
                shape == shape1
            ), "The shape does not match with the existing tensor"
    def __del__(self):
        initialized = (
            os.environ.get("DGL_DIST_MODE", "standalone") == "standalone"
            or is_initialized()
        )
        if not self._persistent and self._owner and initialized:
            self.kvstore.delete_data(self._name)
@@ -198,7 +227,7 @@ class DistTensor:
            part_policy=self._part_policy,
            persistent=self._persistent,
            is_gdata=self._is_gdata,
            attach=self._attach,
        )
        kvstore = self.kvstore
        kvstore.union(self._name, other._name, new_dist_tensor._name)
@@ -209,67 +238,67 @@ class DistTensor:
    @property
    def part_policy(self):
        """Return the partition policy.

        Returns
        -------
        PartitionPolicy
            The partition policy of the distributed tensor.
        """
        return self._part_policy

    @property
    def shape(self):
        """Return the shape of the distributed tensor.

        Returns
        -------
        tuple
            The shape of the distributed tensor.
        """
        return self._shape

    @property
    def dtype(self):
        """Return the data type of the distributed tensor.

        Returns
        -------
        dtype
            The data type of the tensor.
        """
        return self._dtype

    @property
    def name(self):
        """Return the name of the distributed tensor.

        Returns
        -------
        str
            The name of the tensor.
        """
        return self._detach_group_id(self._name)

    @property
    def tensor_name(self):
        """Return the tensor name.

        Returns
        -------
        str
            The name of the tensor.
        """
        return self._detach_group_id(self._tensor_name)

    def count_nonzero(self):
        """Count and return the number of nonzero values.

        Returns
        -------
        int
            The number of nonzero values.
        """
        return self.kvstore.count_nonzero(name=self._name)
    def _attach_group_id(self, name):
@@ -295,4 +324,4 @@ class DistTensor:
        if not self._attach:
            return name
        suffix = "_{}".format(get_group_id())
        return name[: -len(suffix)]
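
A short sketch of the ``DistTensor`` interface above (illustrative; it assumes
``dgl.distributed.initialize`` has been called and a ``DistGraph`` exists so a
node partition policy can be inferred from the first dimension):

import dgl
import torch as th

dgl.distributed.initialize("ip_config.txt")
g = dgl.distributed.DistGraph("graph_name")
# Creation is synchronized across trainers; reads and writes are served by
# the kvstore servers.
emb = dgl.distributed.DistTensor((g.num_nodes(), 10), th.float32, name="emb")
print(emb[0:3])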
@@ -5,8 +5,7 @@ from abc import ABC
import numpy as np

from .. import backend as F, utils
from .._ffi.ndarray import empty_shared_mem
from ..base import DGLError
from ..ndarray import exist_shared_mem_array
@@ -14,16 +13,17 @@ from ..partition import NDArrayPartition
from .constants import DEFAULT_ETYPE, DEFAULT_NTYPE
from .id_map import IdMap
from .shared_mem_utils import (
    _get_edata_path,
    _get_ndata_path,
    _to_shared_mem,
    DTYPE_DICT,
)

CANONICAL_ETYPE_DELIMITER = ":"


def _etype_tuple_to_str(c_etype):
    """Convert canonical etype from tuple to string.

    Examples
    --------
@@ -32,14 +32,16 @@ def _etype_tuple_to_str(c_etype):
    >>> print(c_etype_str)
    'user:like:item'
    """
    assert isinstance(c_etype, tuple) and len(c_etype) == 3, (
        "Passed-in canonical etype should be in format of (str, str, str). "
        f"But got {c_etype}."
    )
    return CANONICAL_ETYPE_DELIMITER.join(c_etype)


def _etype_str_to_tuple(c_etype):
    """Convert canonical etype from string to tuple.

    Examples
    --------
@@ -48,13 +50,15 @@ def _etype_str_to_tuple(c_etype):
    >>> print(c_etype)
    ('user', 'like', 'item')
    """
    ret = tuple(c_etype.split(CANONICAL_ETYPE_DELIMITER))
    assert len(ret) == 3, (
        "Passed-in canonical etype should be in format of 'str:str:str'. "
        f"But got {c_etype}."
    )
    return ret


def _move_metadata_to_shared_mem(
    graph_name,
    num_nodes,
@@ -533,6 +537,7 @@ class GraphPartitionBook(ABC):
            Homogeneous edge IDs.
        """


class RangePartitionBook(GraphPartitionBook):
    """This partition book supports more efficient storage of partition information.
@@ -582,9 +587,10 @@ class RangePartitionBook(GraphPartitionBook):
            ntype is not None for ntype in self._ntypes
        ), "The node types have invalid IDs."
        for c_etype, etype_id in etypes.items():
            assert isinstance(c_etype, tuple) and len(c_etype) == 3, (
                "Expect canonical edge type in a triplet of string, but got "
                f"{c_etype}."
            )
            etype = c_etype[1]
            self._etypes[etype_id] = etype
            self._canonical_etypes[etype_id] = c_etype
@@ -660,13 +666,19 @@ class RangePartitionBook(GraphPartitionBook):
        # to local heterogenized node/edge IDs. One can do the mapping by binary search
        # on these arrays.
        self._local_ntype_offset = np.cumsum(
            [0]
            + [
                v[self._partid, 1] - v[self._partid, 0]
                for v in self._typed_nid_range.values()
            ]
        ).tolist()
        self._local_etype_offset = np.cumsum(
            [0]
            + [
                v[self._partid, 1] - v[self._partid, 0]
                for v in self._typed_eid_range.values()
            ]
        ).tolist()

        # Get meta data of the partition book
        self._partition_meta_data = []
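
To make the binary-search remark concrete, a small NumPy sketch (the per-type
counts are made up) of how such an offset array maps a local heterogenized ID
to a (type, per-type ID) pair:

import numpy as np

# Say this partition holds 5 nodes of type 0 and 3 nodes of type 1,
# giving the offset array [0, 5, 8].
local_offset = np.cumsum([0, 5, 3]).tolist()
local_nid = 6
type_id = np.searchsorted(local_offset, local_nid, side="right") - 1
per_type_id = local_nid - local_offset[type_id]
assert (type_id, per_type_id) == (1, 1)  # the second node of type 1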
@@ -945,7 +957,7 @@ class RangePartitionBook(GraphPartitionBook):
NODE_PART_POLICY = "node"
EDGE_PART_POLICY = "edge"
POLICY_DELIMITER = "~"


class PartitionPolicy(object):
@@ -967,8 +979,9 @@ class PartitionPolicy(object):
    """

    def __init__(self, policy_str, partition_book):
        assert policy_str.startswith(NODE_PART_POLICY) or policy_str.startswith(
            EDGE_PART_POLICY
        ), (
            f"policy_str must start with {NODE_PART_POLICY} or "
            f"{EDGE_PART_POLICY}, but got {policy_str}."
        )
@@ -1127,11 +1140,12 @@ class EdgePartitionPolicy(PartitionPolicy):
    """Partition policy for edges."""

    def __init__(self, partition_book, etype=DEFAULT_ETYPE):
        assert (
            isinstance(etype, tuple) and len(etype) == 3
        ), f"Expect canonical edge type in a triplet of string, but got {etype}."
        super(EdgePartitionPolicy, self).__init__(
            EDGE_PART_POLICY + POLICY_DELIMITER + _etype_tuple_to_str(etype),
            partition_book,
        )
@@ -1156,9 +1170,10 @@ class HeteroDataName(object):
    def __init__(self, is_node, entity_type, data_name):
        self._policy = NODE_PART_POLICY if is_node else EDGE_PART_POLICY
        if not is_node:
            assert isinstance(entity_type, tuple) and len(entity_type) == 3, (
                "Expect canonical edge type in a triplet of string, but got "
                f"{entity_type}."
            )
        self._entity_type = entity_type
        self.data_name = data_name
@@ -1226,6 +1241,4 @@ def parse_hetero_data_name(name):
    entity_type = names[1]
    if not is_node:
        entity_type = _etype_str_to_tuple(entity_type)
    return HeteroDataName(is_node, entity_type, names[2])
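
Concretely, a heterogeneous data name joins the policy, the entity type, and
the user-facing data name with ``POLICY_DELIMITER``; for example (the feature
name "feat" is a placeholder):

# Node data of type "user":                   "node~user~feat"
# Edge data of type ("user", "like", "item"): "edge~user:like:item~feat"
hetero_name = parse_hetero_data_name("edge~user:like:item~feat")
# The parsed object records the edge policy, the canonical etype tuple
# ("user", "like", "item"), and the data name "feat".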
@@ -6,16 +6,17 @@ import numpy as np
from .. import backend as F
from ..base import EID, NID
from ..convert import graph, heterograph
from ..sampling import (
    sample_etype_neighbors as local_sample_etype_neighbors,
    sample_neighbors as local_sample_neighbors,
)
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
    recv_responses,
    register_service,
    Request,
    Response,
    send_requests_to_machine,
)
@@ -207,6 +208,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
# This is a limitation of the current DistDGL design. We should improve it
# later.


class SamplingRequest(Request):
    """Sampling Request"""
@@ -798,9 +800,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
    def local_access(local_g, partition_book, local_nids):
        # See NOTE 1
        _prob = [g.edata[prob].local_partition] if prob is not None else None
        return _sample_neighbors(
            local_g,
            partition_book,
...
"""Module for mapping between node/edge IDs and node/edge types.""" """Module for mapping between node/edge IDs and node/edge types."""
import numpy as np import numpy as np
from .. import backend as F, utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F
from .. import utils


class IdMap:
    """A map for converting node/edge IDs to their type IDs and type-wise IDs.

    For a heterogeneous graph, DGL assigns an integer ID to each node/edge type;
    nodes and edges of different types have independent IDs starting from zero.
@@ -96,7 +97,8 @@ class IdMap:
        for a particular node type in a partition. For example, all nodes of type ``"T"`` in
        partition ``i`` have IDs in the range ``id_ranges["T"][i][0]`` to ``id_ranges["T"][i][1]``.
        It is the same as the `node_map` argument in `RangePartitionBook`.
    """
    def __init__(self, id_ranges):
        self.num_parts = list(id_ranges.values())[0].shape[0]
        self.num_types = len(id_ranges)
@@ -105,7 +107,7 @@ class IdMap:
        id_ranges = list(id_ranges.values())
        id_ranges.sort(key=lambda a: a[0, 0])
        for i, id_range in enumerate(id_ranges):
            ranges[i :: self.num_types] = id_range
            map1 = np.cumsum(id_range[:, 1] - id_range[:, 0])
            typed_map.append(map1)
@@ -116,7 +118,7 @@ class IdMap:
        self.typed_map = utils.toindex(np.concatenate(typed_map))

    def __call__(self, ids):
        """Convert the homogeneous IDs to (type_id, type_wise_id).

        Parameters
        ----------
@@ -129,19 +131,23 @@ class IdMap:
            Type IDs
        per_type_ids : Tensor
            Type-wise IDs
        """
        if self.num_types == 0:
            return F.zeros((len(ids),), F.dtype(ids), F.cpu()), ids
        if len(ids) == 0:
            return ids, ids
        ids = utils.toindex(ids)
        ret = _CAPI_DGLHeteroMapIds(
            ids.todgltensor(),
            self.range_start.todgltensor(),
            self.range_end.todgltensor(),
            self.typed_map.todgltensor(),
            self.num_parts,
            self.num_types,
        )
        ret = utils.toindex(ret).tousertensor()
        return ret[: len(ids)], ret[len(ids) :]


_init_api("dgl.distributed.id_map")
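
The C kernel ``_CAPI_DGLHeteroMapIds`` is not shown in this diff; the following
pure-NumPy sketch computes the same (type_id, per_type_id) pairs for the
single-partition case (an illustration, not the actual implementation):

import numpy as np

# One partition, two types: type 0 owns IDs [0, 5), type 1 owns IDs [5, 8).
range_start = np.array([0, 5])
range_end = np.array([5, 8])
ids = np.array([1, 6, 4])
# Locate the range holding each ID, then subtract that range's start.
type_ids = np.searchsorted(range_start, ids, side="right") - 1
assert np.all(ids < range_end[type_ids])  # every ID falls inside its range
per_type_ids = ids - range_start[type_ids]
print(type_ids)      # [0 1 0]
print(per_type_ids)  # [1 1 4]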
"""Define distributed kvstore""" """Define distributed kvstore"""
import os import os
import numpy as np import numpy as np
from .. import backend as F, utils
from .._ffi.ndarray import empty_shared_mem
from . import rpc from . import rpc
from .graph_partition_book import NodePartitionPolicy, EdgePartitionPolicy from .graph_partition_book import EdgePartitionPolicy, NodePartitionPolicy
from .standalone_kvstore import KVClient as SA_KVClient from .standalone_kvstore import KVClient as SA_KVClient
from .. import backend as F
from .. import utils
from .._ffi.ndarray import empty_shared_mem
############################ Register KVStore Requsts and Responses ############################### ############################ Register KVStore Requsts and Responses ###############################

KVSTORE_PULL = 901231


class PullResponse(rpc.Response):
    """Send the sliced data tensor back to the client.
@@ -25,6 +26,7 @@ class PullResponse(rpc.Response):
    data_tensor : tensor
        sliced data tensor
    """

    def __init__(self, server_id, data_tensor):
        self.server_id = server_id
        self.data_tensor = data_tensor
@@ -35,6 +37,7 @@ class PullResponse(rpc.Response):
    def __setstate__(self, state):
        self.server_id, self.data_tensor = state


class PullRequest(rpc.Request):
    """Send ID tensor to server and get target data tensor as response.
@@ -45,6 +48,7 @@ class PullRequest(rpc.Request):
    id_tensor : tensor
        a vector storing the data ID
    """

    def __init__(self, name, id_tensor):
        self.name = name
        self.id_tensor = id_tensor
@@ -58,16 +62,25 @@ class PullRequest(rpc.Request):
    def process_request(self, server_state):
        kv_store = server_state.kv_store
        if self.name not in kv_store.part_policy:
            raise RuntimeError(
                "KVServer cannot find partition policy with name: %s"
                % self.name
            )
        if self.name not in kv_store.data_store:
            raise RuntimeError(
                "KVServer Cannot find data tensor with name: %s" % self.name
            )
        local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
        data = kv_store.pull_handlers[self.name](
            kv_store.data_store, self.name, local_id
        )
        res = PullResponse(kv_store.server_id, data)
        return res


KVSTORE_PUSH = 901232


class PushRequest(rpc.Request):
    """Send ID tensor and data tensor to server and update kvstore's data.
@@ -82,6 +95,7 @@ class PushRequest(rpc.Request):
    data_tensor : tensor
        a tensor with the same row size of data ID
    """

    def __init__(self, name, id_tensor, data_tensor):
        self.name = name
        self.id_tensor = id_tensor
@@ -96,15 +110,23 @@ class PushRequest(rpc.Request):
    def process_request(self, server_state):
        kv_store = server_state.kv_store
        if self.name not in kv_store.part_policy:
            raise RuntimeError(
                "KVServer cannot find partition policy with name: %s"
                % self.name
            )
        if self.name not in kv_store.data_store:
            raise RuntimeError(
                "KVServer Cannot find data tensor with name: %s" % self.name
            )
        local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
        kv_store.push_handlers[self.name](
            kv_store.data_store, self.name, local_id, self.data_tensor
        )


INIT_DATA = 901233
INIT_MSG = "Init"


class InitDataResponse(rpc.Response):
    """Send a confirmation response (just a short string message) of
@@ -115,6 +137,7 @@ class InitDataResponse(rpc.Response):
    msg : string
        string message
    """

    def __init__(self, msg):
        self.msg = msg
@@ -124,6 +147,7 @@ class InitDataResponse(rpc.Response):
    def __setstate__(self, state):
        self.msg = state


class InitDataRequest(rpc.Request):
    """Send meta data to server and init data tensor
    on server using UDF init function.
@@ -141,6 +165,7 @@ class InitDataRequest(rpc.Request):
    init_func : function
        UDF init function.
    """

    def __init__(self, name, shape, dtype, policy_str, init_func):
        self.name = name
        self.shape = shape
@@ -149,10 +174,22 @@ class InitDataRequest(rpc.Request):
        self.init_func = init_func

    def __getstate__(self):
        return (
            self.name,
            self.shape,
            self.dtype,
            self.policy_str,
            self.init_func,
        )

    def __setstate__(self, state):
        (
            self.name,
            self.shape,
            self.dtype,
            self.policy_str,
            self.init_func,
        ) = state
    def process_request(self, server_state):
        kv_store = server_state.kv_store
@@ -161,22 +198,33 @@ class InitDataRequest(rpc.Request):
        # We should see requests from multiple clients. We need to ignore the duplicated
        # requests.
        if self.name in kv_store.data_store:
            assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(
                self.shape
            )
            assert (
                F.reverse_data_type_dict[
                    F.dtype(kv_store.data_store[self.name])
                ]
                == self.dtype
            )
            assert kv_store.part_policy[self.name].policy_str == self.policy_str
        else:
            if not kv_store.is_backup_server():
                data_tensor = self.init_func(self.shape, dtype)
                kv_store.init_data(
                    name=self.name,
                    policy_str=self.policy_str,
                    data_tensor=data_tensor,
                )
            else:
                kv_store.init_data(name=self.name, policy_str=self.policy_str)
        res = InitDataResponse(INIT_MSG)
        return res


BARRIER = 901234
BARRIER_MSG = "Barrier"


class BarrierResponse(rpc.Response):
    """Send a confirmation signal (just a short string message) of
@@ -187,6 +235,7 @@ class BarrierResponse(rpc.Response):
    msg : string
        string msg
    """

    def __init__(self, msg):
        self.msg = msg
@@ -196,6 +245,7 @@ class BarrierResponse(rpc.Response):
    def __setstate__(self, state):
        self.msg = state


class BarrierRequest(rpc.Request):
    """Send a barrier signal (just a short string message) to server.
@@ -204,6 +254,7 @@ class BarrierRequest(rpc.Request):
    role : string
        client role
    """

    def __init__(self, role):
        self.role = role
        self.group_id = rpc.get_group_id()
@@ -229,8 +280,10 @@ class BarrierRequest(rpc.Request):
            return res_list
        return None

REGISTER_PULL = 901235
REGISTER_PULL_MSG = "Register_Pull"


class RegisterPullHandlerResponse(rpc.Response):
    """Send a confirmation signal (just a short string message) of
@@ -241,6 +294,7 @@ class RegisterPullHandlerResponse(rpc.Response):
    msg : string
        string message
    """

    def __init__(self, msg):
        self.msg = msg
@@ -250,6 +304,7 @@ class RegisterPullHandlerResponse(rpc.Response):
    def __setstate__(self, state):
        self.msg = state


class RegisterPullHandlerRequest(rpc.Request):
    """Send a UDF and register Pull handler on server.
@@ -258,6 +313,7 @@ class RegisterPullHandlerRequest(rpc.Request):
    pull_func : func
        UDF pull handler
    """

    def __init__(self, name, pull_func):
        self.name = name
        self.pull_func = pull_func
@@ -274,8 +330,10 @@ class RegisterPullHandlerRequest(rpc.Request):
        res = RegisterPullHandlerResponse(REGISTER_PULL_MSG)
        return res

REGISTER_PUSH = 901236
REGISTER_PUSH_MSG = "Register_Push"


class RegisterPushHandlerResponse(rpc.Response):
    """Send a confirmation signal (just a short string message) of
@@ -286,6 +344,7 @@ class RegisterPushHandlerResponse(rpc.Response):
    msg : string
        string message
    """

    def __init__(self, msg):
        self.msg = msg
@@ -295,6 +354,7 @@ class RegisterPushHandlerResponse(rpc.Response):
    def __setstate__(self, state):
        self.msg = state


class RegisterPushHandlerRequest(rpc.Request):
    """Send a UDF to register Push handler on server.
@@ -303,6 +363,7 @@ class RegisterPushHandlerRequest(rpc.Request):
    push_func : func
        UDF push handler
    """

    def __init__(self, name, push_func):
        self.name = name
        self.push_func = push_func
@@ -319,8 +380,10 @@ class RegisterPushHandlerRequest(rpc.Request):
        res = RegisterPushHandlerResponse(REGISTER_PUSH_MSG)
        return res

GET_SHARED = 901237
GET_SHARED_MSG = "Get_Shared"


class GetSharedDataResponse(rpc.Response):
    """Send meta data of shared-memory tensor to client.
@@ -333,6 +396,7 @@ class GetSharedDataResponse(rpc.Response):
        {'data_0' : (shape, dtype, policy_str),
         'data_1' : (shape, dtype, policy_str)}
    """

    def __init__(self, meta):
        self.meta = meta
@@ -342,6 +406,7 @@ class GetSharedDataResponse(rpc.Response):
    def __setstate__(self, state):
        self.meta = state


class GetSharedDataRequest(rpc.Request):
    """Send a signal (just a short string message) to get the
    meta data of shared-tensor from server.
@@ -351,6 +416,7 @@ class GetSharedDataRequest(rpc.Request):
    msg : string
        string message
    """

    def __init__(self, msg):
        self.msg = msg
@@ -368,14 +434,18 @@ class GetSharedDataRequest(rpc.Request):
            if server_state.keep_alive:
                if name not in kv_store.orig_data:
                    continue
            meta[name] = (
                F.shape(data),
                F.reverse_data_type_dict[F.dtype(data)],
                kv_store.part_policy[name].policy_str,
            )
        res = GetSharedDataResponse(meta)
        return res

GET_PART_SHAPE = 901238


class GetPartShapeResponse(rpc.Response):
    """Send the partitioned data shape back to client.
@@ -384,6 +454,7 @@ class GetPartShapeResponse(rpc.Response):
    shape : tuple
        shape of tensor
    """

    def __init__(self, shape):
        self.shape = shape
@@ -397,6 +468,7 @@ class GetPartShapeResponse(rpc.Response):
        else:
            self.shape = state


class GetPartShapeRequest(rpc.Request):
    """Send data name to get the partitioned data shape from server.
@@ -405,6 +477,7 @@ class GetPartShapeRequest(rpc.Request):
    name : str
        data name
    """

    def __init__(self, name):
        self.name = name
@@ -417,18 +490,23 @@ class GetPartShapeRequest(rpc.Request):
    def process_request(self, server_state):
        kv_store = server_state.kv_store
        if self.name not in kv_store.data_store:
            raise RuntimeError(
                "KVServer Cannot find data tensor with name: %s" % self.name
            )
        data_shape = F.shape(kv_store.data_store[self.name])
        res = GetPartShapeResponse(data_shape)
        return res

SEND_META_TO_BACKUP = 901239
SEND_META_TO_BACKUP_MSG = "Send_Meta_TO_Backup"


class SendMetaToBackupResponse(rpc.Response):
    """Send a confirmation signal (just a short string message)
    of SendMetaToBackupRequest to client.
    """

    def __init__(self, msg):
        self.msg = msg
@@ -438,6 +516,7 @@ class SendMetaToBackupResponse(rpc.Response):
    def __setstate__(self, state):
        self.msg = state


class SendMetaToBackupRequest(rpc.Request):
    """Send meta data to backup server and backup server
    will use this meta data to read shared-memory tensor.
@@ -457,7 +536,10 @@ class SendMetaToBackupRequest(rpc.Request):
    push_handler : callable
        The callback function when data is pushed to kvstore.
    """

    def __init__(
        self, name, dtype, shape, policy_str, pull_handler, push_handler
    ):
        self.name = name
        self.dtype = dtype
        self.shape = shape
@@ -466,39 +548,65 @@ class SendMetaToBackupRequest(rpc.Request):
        self.push_handler = push_handler

    def __getstate__(self):
        return (
            self.name,
            self.dtype,
            self.shape,
            self.policy_str,
            self.pull_handler,
            self.push_handler,
        )

    def __setstate__(self, state):
        (
            self.name,
            self.dtype,
            self.shape,
            self.policy_str,
            self.pull_handler,
            self.push_handler,
        ) = state
    def process_request(self, server_state):
        kv_store = server_state.kv_store
        assert kv_store.is_backup_server()
        if self.name not in kv_store.data_store:
            shared_data = empty_shared_mem(
                self.name + "-kvdata-", False, self.shape, self.dtype
            )
            dlpack = shared_data.to_dlpack()
            kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)
            kv_store.part_policy[self.name] = kv_store.find_policy(
                self.policy_str
            )
            kv_store.pull_handlers[self.name] = self.pull_handler
            kv_store.push_handlers[self.name] = self.push_handler
        else:
            assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(
                self.shape
            )
            assert (
                F.reverse_data_type_dict[
                    F.dtype(kv_store.data_store[self.name])
                ]
                == self.dtype
            )
            assert kv_store.part_policy[self.name].policy_str == self.policy_str
            assert kv_store.pull_handlers[self.name] == self.pull_handler
            assert kv_store.push_handlers[self.name] == self.push_handler
        res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
        return res

DELETE_DATA = 901240
DELETE_MSG = "Delete_Data"


class DeleteDataResponse(rpc.Response):
    """Send a confirmation signal (just a short string message)
    of DeleteDataRequest to client.
    """

    def __init__(self, msg):
        self.msg = msg
@@ -508,6 +616,7 @@ class DeleteDataResponse(rpc.Response):
    def __setstate__(self, state):
        self.msg = state


class DeleteDataRequest(rpc.Request):
    """Send message to server to delete data tensor
@@ -516,6 +625,7 @@ class DeleteDataRequest(rpc.Request):
    name : str
        data name
    """

    def __init__(self, name):
        self.name = name
@@ -535,11 +645,13 @@ class DeleteDataRequest(rpc.Request):
        res = DeleteDataResponse(DELETE_MSG)
        return res

COUNT_LOCAL_NONZERO = 901241


class CountLocalNonzeroResponse(rpc.Response):
    """Send the number of nonzero values in local data"""

    def __init__(self, num_local_nonzero):
        self.num_local_nonzero = num_local_nonzero
@@ -549,6 +661,7 @@ class CountLocalNonzeroResponse(rpc.Response):
    def __setstate__(self, state):
        self.num_local_nonzero = state


class CountLocalNonzeroRequest(rpc.Request):
    """Send data name to server to count local nonzero values.

    Parameters
@@ -556,6 +669,7 @@ class CountLocalNonzeroRequest(rpc.Request):
    name : str
        data name
    """

    def __init__(self, name):
        self.name = name
@@ -571,8 +685,10 @@ class CountLocalNonzeroRequest(rpc.Request):
        res = CountLocalNonzeroResponse(num_local_nonzero)
        return res

############################ KVServer ###############################


def default_push_handler(target, name, id_tensor, data_tensor):
    """Default handler for PUSH message.
@@ -592,6 +708,7 @@ def default_push_handler(target, name, id_tensor, data_tensor):
    # TODO(chao): support Tensorflow backend
    target[name][id_tensor] = data_tensor


def default_pull_handler(target, name, id_tensor):
    """Default handler for PULL operation.
@@ -614,6 +731,7 @@ def default_pull_handler(target, name, id_tensor):
    # TODO(chao): support Tensorflow backend
    return target[name][id_tensor]
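
Custom handlers use the same signature as the defaults above; for instance, a
hypothetical accumulating push handler that adds pushed rows instead of
overwriting them:

def add_push_handler(target, name, id_tensor, data_tensor):
    """Accumulate pushed rows into the stored tensor rather than overwrite."""
    target[name][id_tensor] += data_tensor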


class KVServer(object):
    """KVServer is a lightweight key-value store service for DGL distributed training.
@@ -636,45 +754,50 @@ class KVServer(object):
    num_clients : int
        Total number of KVClients that will be connected to the KVServer.
    """

    def __init__(self, server_id, ip_config, num_servers, num_clients):
        assert server_id >= 0, (
            "server_id (%d) cannot be a negative number." % server_id
        )
        assert num_servers > 0, (
            "num_servers (%d) must be a positive number." % num_servers
        )
        assert os.path.exists(ip_config), "Cannot open file: %s" % ip_config
        assert num_clients >= 0, (
            "num_clients (%d) cannot be a negative number." % num_clients
        )
        # Register services on server
        rpc.register_service(KVSTORE_PULL, PullRequest, PullResponse)
        rpc.register_service(KVSTORE_PUSH, PushRequest, None)
        rpc.register_service(INIT_DATA, InitDataRequest, InitDataResponse)
        rpc.register_service(BARRIER, BarrierRequest, BarrierResponse)
        rpc.register_service(
            REGISTER_PUSH,
            RegisterPushHandlerRequest,
            RegisterPushHandlerResponse,
        )
        rpc.register_service(
            REGISTER_PULL,
            RegisterPullHandlerRequest,
            RegisterPullHandlerResponse,
        )
        rpc.register_service(
            GET_SHARED, GetSharedDataRequest, GetSharedDataResponse
        )
        rpc.register_service(
            GET_PART_SHAPE, GetPartShapeRequest, GetPartShapeResponse
        )
        rpc.register_service(
            SEND_META_TO_BACKUP,
            SendMetaToBackupRequest,
            SendMetaToBackupResponse,
        )
        rpc.register_service(DELETE_DATA, DeleteDataRequest, DeleteDataResponse)
        rpc.register_service(
            COUNT_LOCAL_NONZERO,
            CountLocalNonzeroRequest,
            CountLocalNonzeroResponse,
        )
        # Store the tensor data with specified data name
        self._data_store = {}
        # Store original tensor data names when instantiating DistGraphServer
@@ -685,9 +808,11 @@ class KVServer(object):
        # Basic information
        self._server_id = server_id
        self._server_namebook = rpc.read_ip_config(ip_config, num_servers)
        assert (
            server_id in self._server_namebook
        ), "Trying to start server {}, but there are {} servers in the config file".format(
            server_id, len(self._server_namebook)
        )
        self._machine_id = self._server_namebook[server_id][0]
        self._group_count = self._server_namebook[server_id][3]
        # We assume partition_id is equal to machine_id
@@ -749,8 +874,7 @@ class KVServer(object):
        return self._pull_handlers

    def is_backup_server(self):
        """Return True if current server is a backup server."""
        if self._server_id % self._group_count == 0:
            return False
        return True
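
In other words, the servers on one machine are numbered consecutively in
groups of ``group_count``, and only the first server of each group is a main
server; e.g. with ``group_count = 3``:

# server_id: 0    1    2    3    4    5
# backup?:   no   yes  yes  no   yes  yes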
@@ -778,22 +902,26 @@ class KVServer(object):
            If the data_tensor is None, KVServer will
            read shared-memory when client invoking get_shared_data().
        """
        assert len(name) > 0, "name cannot be empty."
        if name in self._data_store:
            raise RuntimeError("Data %s already exists!" % name)
        self._part_policy[name] = self.find_policy(policy_str)
        if data_tensor is not None:  # Create shared-tensor
            data_type = F.reverse_data_type_dict[F.dtype(data_tensor)]
            shared_data = empty_shared_mem(
                name + "-kvdata-", True, data_tensor.shape, data_type
            )
            dlpack = shared_data.to_dlpack()
            self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
            rpc.copy_data_to_shared_memory(self._data_store[name], data_tensor)
            assert (
                self._part_policy[name].get_part_size() == data_tensor.shape[0]
            ), "kvserver expect partition {} for {} has {} rows, but gets {} rows".format(
                self._part_policy[name].part_id,
                policy_str,
                self._part_policy[name].get_part_size(),
                data_tensor.shape[0],
            )
        self._pull_handlers[name] = default_pull_handler
        self._push_handlers[name] = default_push_handler
@@ -808,7 +936,9 @@ class KVServer(object):
        for policy in self._policy_set:
            if policy_str == policy.policy_str:
                return policy
        raise RuntimeError(
            "Cannot find policy_str: %s from kvserver." % policy_str
        )

    def count_local_nonzero(self, name):
        """Count nonzero values in local data.
@@ -823,13 +953,15 @@ class KVServer(object):
        int
            the number of nonzero values in local data.
        """
        assert len(name) > 0, "name cannot be empty."
        if name not in self._data_store:
            raise RuntimeError("Data %s has not been created!" % name)
        return F.count_nonzero(self._data_store[name])
############################ KVClient ###############################


class KVClient(object):
    """KVClient is used to push/pull data to/from KVServer. If the
    target kvclient and kvserver are on the same machine, they can
...@@ -849,45 +981,47 @@ class KVClient(object):
    role : str
        We can set different roles for the kvstore.
    """

    def __init__(self, ip_config, num_servers, role="default"):
        assert (
            rpc.get_rank() != -1
        ), "Please invoke rpc.connect_to_server() before creating KVClient."
        assert os.path.exists(ip_config), "Cannot open file: %s" % ip_config
        assert num_servers > 0, (
            "num_servers (%d) must be a positive number." % num_servers
        )
        # Register services on client
        rpc.register_service(KVSTORE_PULL, PullRequest, PullResponse)
        rpc.register_service(KVSTORE_PUSH, PushRequest, None)
        rpc.register_service(INIT_DATA, InitDataRequest, InitDataResponse)
        rpc.register_service(BARRIER, BarrierRequest, BarrierResponse)
        rpc.register_service(
            REGISTER_PUSH,
            RegisterPushHandlerRequest,
            RegisterPushHandlerResponse,
        )
        rpc.register_service(
            REGISTER_PULL,
            RegisterPullHandlerRequest,
            RegisterPullHandlerResponse,
        )
        rpc.register_service(
            GET_SHARED, GetSharedDataRequest, GetSharedDataResponse
        )
        rpc.register_service(
            GET_PART_SHAPE, GetPartShapeRequest, GetPartShapeResponse
        )
        rpc.register_service(
            SEND_META_TO_BACKUP,
            SendMetaToBackupRequest,
            SendMetaToBackupResponse,
        )
        rpc.register_service(DELETE_DATA, DeleteDataRequest, DeleteDataResponse)
        rpc.register_service(
            COUNT_LOCAL_NONZERO,
            CountLocalNonzeroRequest,
            CountLocalNonzeroResponse,
        )
        # Store the tensor data with specified data name
        self._data_store = {}
        # Store the partition information with specified data name
...@@ -1015,7 +1149,9 @@ class KVClient(object):
        self._pull_handlers[name] = func
        self.barrier()

    def init_data(
        self, name, shape, dtype, part_policy, init_func, is_gdata=True
    ):
        """Send a message to kvserver to initialize a new data tensor and map this
        data from the server side to the client side.
...@@ -1034,9 +1170,11 @@ class KVClient(object):
        is_gdata : bool
            Whether the created tensor is a ndata/edata or not.
        """
        assert len(name) > 0, "name cannot be empty."
        assert len(shape) > 0, "shape cannot be empty"
        assert name not in self._data_name_list, (
            "data name: %s already exists." % name
        )
        self.barrier()
        shape = list(shape)
...@@ -1044,11 +1182,13 @@ class KVClient(object):
        # The servers may handle the duplicated initializations.
        part_shape = shape.copy()
        part_shape[0] = part_policy.get_part_size()
        request = InitDataRequest(
            name,
            tuple(part_shape),
            F.reverse_data_type_dict[dtype],
            part_policy.policy_str,
            init_func,
        )
        # The request is sent to the servers in one group, which are on the same machine.
        for n in range(self._group_count):
            server_id = part_policy.part_id * self._group_count + n
...@@ -1069,8 +1209,12 @@ class KVClient(object):
            raise RuntimeError("Data shape %s already exists!" % name)
        self._part_policy[name] = part_policy
        self._all_possible_part_policy[part_policy.policy_str] = part_policy
        shared_data = empty_shared_mem(
            name + "-kvdata-",
            False,
            local_shape,
            F.reverse_data_type_dict[dtype],
        )
        dlpack = shared_data.to_dlpack()
        self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
        self._data_name_list.add(name)
...@@ -1081,16 +1225,20 @@ class KVClient(object):
        self._push_handlers[name] = default_push_handler
        # Now we need to tell the backup servers about the new tensor.
        request = SendMetaToBackupRequest(
            name,
            F.reverse_data_type_dict[dtype],
            part_shape,
            part_policy.policy_str,
            self._pull_handlers[name],
            self._push_handlers[name],
        )
        # send request to all the backup server nodes
        for i in range(self._group_count - 1):
            server_id = self._machine_id * self._group_count + i + 1
            rpc.send_request(server_id, request)
        # recv response from all the backup server nodes
        for _ in range(self._group_count - 1):
            response = rpc.recv_response()
            assert response.msg == SEND_META_TO_BACKUP_MSG
        self.barrier()
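The loops above address servers by position: the servers of partition p occupy IDs [p * group_count, (p + 1) * group_count), with the first ID in each group acting as the main server. A quick sketch with toy numbers:

# Toy numbers: 2 partitions (machines), group_count = 3 servers per machine.
group_count = 3
for part_id in range(2):
    servers = [part_id * group_count + n for n in range(group_count)]
    backups = servers[1:]  # the main server is the first ID in the group
    print(part_id, servers, backups)
# Partition 0 -> servers [0, 1, 2]; partition 1 -> servers [3, 4, 5].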
...@@ -1103,8 +1251,8 @@ class KVClient(object):
        name : str
            data name
        """
        assert len(name) > 0, "name cannot be empty."
        assert name in self._data_name_list, "data name: %s does not exist." % name
        self.barrier()
        part_policy = self._part_policy[name]
...@@ -1154,10 +1302,14 @@ class KVClient(object):
        if name not in self._data_name_list:
            shape, dtype, policy_str = meta
            assert policy_str in self._all_possible_part_policy
            shared_data = empty_shared_mem(
                name + "-kvdata-", False, shape, dtype
            )
            dlpack = shared_data.to_dlpack()
            self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
            self._part_policy[name] = self._all_possible_part_policy[
                policy_str
            ]
            self._pull_handlers[name] = default_pull_handler
            self._push_handlers[name] = default_push_handler
        # Get full data shape across servers
...@@ -1179,15 +1331,20 @@ class KVClient(object):
        # Send meta data to backup servers
        for name, meta in response.meta.items():
            shape, dtype, policy_str = meta
            request = SendMetaToBackupRequest(
                name,
                dtype,
                shape,
                policy_str,
                self._pull_handlers[name],
                self._push_handlers[name],
            )
            # send request to all the backup server nodes
            for i in range(self._group_count - 1):
                server_id = self._machine_id * self._group_count + i + 1
                rpc.send_request(server_id, request)
            # recv response from all the backup server nodes
            for _ in range(self._group_count - 1):
                response = rpc.recv_response()
                assert response.msg == SEND_META_TO_BACKUP_MSG
            self._data_name_list.add(name)
...@@ -1205,9 +1362,8 @@ class KVClient(object):
        return list(self._data_name_list)

    def get_data_meta(self, name):
        """Get meta data (data_type, data_shape, partition_policy)"""
        assert len(name) > 0, "name cannot be empty."
        data_type = F.dtype(self._data_store[name])
        data_shape = self._full_data_shape[name]
        part_policy = self._part_policy[name]
...@@ -1222,16 +1378,15 @@ class KVClient(object):
        id_tensor : tensor
            a vector storing the global data ID
        """
        assert len(name) > 0, "name cannot be empty."
        id_tensor = utils.toindex(id_tensor)
        id_tensor = id_tensor.tousertensor()
        assert F.ndim(id_tensor) == 1, "ID must be a vector."
        # partition data
        machine_id = self._part_policy[name].to_partid(id_tensor)
        return machine_id

    def push(self, name, id_tensor, data_tensor):
        """Push data to KVServer.
...@@ -1246,12 +1401,13 @@ class KVClient(object):
        data_tensor : tensor
            a tensor with the same row size as the data ID
        """
        assert len(name) > 0, "name cannot be empty."
        id_tensor = utils.toindex(id_tensor)
        id_tensor = id_tensor.tousertensor()
        assert F.ndim(id_tensor) == 1, "ID must be a vector."
        assert (
            F.shape(id_tensor)[0] == F.shape(data_tensor)[0]
        ), "The data must have the same row size as the ID."
        # partition data
        machine_id = self._part_policy[name].to_partid(id_tensor)
        # sort index by machine id
...@@ -1279,7 +1435,9 @@ class KVClient(object):
                rpc.send_request_to_machine(machine_idx, request)
            start += count[idx]
        if local_id is not None:  # local push
            self._push_handlers[name](
                self._data_store, name, local_id, local_data
            )

    def pull(self, name, id_tensor):
        """Pull message from KVServer.
...@@ -1296,19 +1454,24 @@ class KVClient(object):
        tensor
            a data tensor with the same row size as id_tensor.
        """
        assert len(name) > 0, "name cannot be empty."
        id_tensor = utils.toindex(id_tensor)
        id_tensor = id_tensor.tousertensor()
        assert F.ndim(id_tensor) == 1, "ID must be a vector."
        if self._pull_handlers[name] is default_pull_handler:  # Use fast-pull
            part_id = self._part_policy[name].to_partid(id_tensor)
            return rpc.fast_pull(
                name,
                id_tensor,
                part_id,
                KVSTORE_PULL,
                self._machine_count,
                self._group_count,
                self._machine_id,
                self._client_id,
                self._data_store[name],
                self._part_policy[name],
            )
        else:
            # partition data
            machine_id = self._part_policy[name].to_partid(id_tensor)
...@@ -1316,7 +1479,9 @@ class KVClient(object):
            sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
            back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
            id_tensor = id_tensor[sorted_id]
            machine, count = np.unique(
                F.asnumpy(machine_id), return_counts=True
            )
            # pull data from the servers in order
            start = 0
            pull_count = 0
...@@ -1338,7 +1503,9 @@ class KVClient(object):
            # recv response
            response_list = []
            if local_id is not None:  # local pull
                local_data = self._pull_handlers[name](
                    self._data_store, name, local_id
                )
                server_id = self._main_server_id
                local_response = PullResponse(server_id, local_data)
                response_list.append(local_response)
...@@ -1348,21 +1515,22 @@ class KVClient(object):
                response_list.append(remote_response)
            # sort responses by server_id and concat tensors
            response_list.sort(key=self._take_id)
            data_tensor = F.cat(
                seq=[response.data_tensor for response in response_list], dim=0
            )
            return data_tensor[
                back_sorted_id
            ]  # return data with original index order
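The slow path above relies on a sort/unsort trick: sort the IDs by owning machine, fetch per machine, then invert the sort to restore the caller's order. A standalone numpy sketch (the ownership rule here is a made-up toy):

# Sketch of pull's ordering trick with a toy ownership rule (ids % 2).
import numpy as np

ids = np.array([7, 0, 5, 2, 9])
machine_id = ids % 2                  # toy owner of each ID (assumption)
sorted_idx = np.argsort(machine_id)   # group IDs by owning machine
back_sorted_idx = np.argsort(sorted_idx)  # inverse permutation

fetched = ids[sorted_idx] * 10        # stand-in for the per-machine pulls
restored = fetched[back_sorted_idx]   # undo the sort
assert (restored == ids * 10).all()   # original caller order recovered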
    def union(self, operand1_name, operand2_name, output_name):
        """Compute the union of two mask arrays in the KVStore."""
        # Each trainer computes its own result from its local storage.
        self._data_store[output_name][:] = (
            self._data_store[operand1_name] | self._data_store[operand2_name]
        )

    def _take_id(self, elem):
        """Used to sort the response list"""
        return elem.server_id

    def count_nonzero(self, name):
...@@ -1382,8 +1550,11 @@ class KVClient(object):
        pull_count = 0
        for machine_id in range(self._machine_count):
            if machine_id == self._machine_id:
                local_id = F.tensor(
                    np.arange(
                        self._part_policy[name].get_part_size(), dtype=np.int64
                    )
                )
                total += F.count_nonzero(self._data_store[name][local_id])
            else:
                request = CountLocalNonzeroRequest(name)
...@@ -1405,22 +1576,26 @@ class KVClient(object):
        """
        return self._data_store


KVCLIENT = None


def init_kvstore(ip_config, num_servers, role):
    """Initialize the KVStore"""
    global KVCLIENT
    if KVCLIENT is None:
        if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
            KVCLIENT = SA_KVClient()
        else:
            KVCLIENT = KVClient(ip_config, num_servers, role)


def close_kvstore():
    """Close the current KVClient"""
    global KVCLIENT
    KVCLIENT = None


def get_kvstore():
    """Get the KVClient"""
    return KVCLIENT
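These three functions implement a module-level singleton: the first init call wins, later calls are no-ops until close resets the global. A minimal standalone sketch of the same pattern:

# Standalone sketch of the KVCLIENT singleton pattern (dict is a stand-in
# factory for KVClient / SA_KVClient).
_CLIENT = None

def init_client(factory):
    global _CLIENT
    if _CLIENT is None:  # only the first call creates the client
        _CLIENT = factory()

def get_client():
    return _CLIENT

def close_client():
    global _CLIENT
    _CLIENT = None

init_client(dict)
assert get_client() is get_client()
close_client()
assert get_client() is None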
"""Define sparse embedding and optimizer.""" """Define sparse embedding and optimizer."""
import torch as th import torch as th
from .... import backend as F
from .... import utils from .... import backend as F, utils
from ...dist_tensor import DistTensor from ...dist_tensor import DistTensor
class DistEmbedding: class DistEmbedding:
'''Distributed node embeddings. """Distributed node embeddings.
DGL provides a distributed embedding to support models that require learnable embeddings. DGL provides a distributed embedding to support models that require learnable embeddings.
DGL's distributed embeddings are mainly used for learning node embeddings of graph models. DGL's distributed embeddings are mainly used for learning node embeddings of graph models.
...@@ -63,11 +64,23 @@ class DistEmbedding: ...@@ -63,11 +64,23 @@ class DistEmbedding:
the forward computation, users have to invoke the forward computation, users have to invoke
py:meth:`~dgl.distributed.optim.SparseAdagrad.step` afterwards. Otherwise, there will be py:meth:`~dgl.distributed.optim.SparseAdagrad.step` afterwards. Otherwise, there will be
some memory leak. some memory leak.
''' """
def __init__(self, num_embeddings, embedding_dim, name=None,
init_func=None, part_policy=None): def __init__(
self._tensor = DistTensor((num_embeddings, embedding_dim), F.float32, name, self,
init_func=init_func, part_policy=part_policy) num_embeddings,
embedding_dim,
name=None,
init_func=None,
part_policy=None,
):
self._tensor = DistTensor(
(num_embeddings, embedding_dim),
F.float32,
name,
init_func=init_func,
part_policy=part_policy,
)
self._trace = [] self._trace = []
self._name = name self._name = name
self._num_embeddings = num_embeddings self._num_embeddings = num_embeddings
...@@ -84,7 +97,7 @@ class DistEmbedding: ...@@ -84,7 +97,7 @@ class DistEmbedding:
self._optm_state = None # track optimizer state self._optm_state = None # track optimizer state
self._part_policy = part_policy self._part_policy = part_policy
def __call__(self, idx, device=th.device('cpu')): def __call__(self, idx, device=th.device("cpu")):
""" """
node_ids : th.tensor node_ids : th.tensor
Index of the embeddings to collect. Index of the embeddings to collect.
...@@ -104,8 +117,7 @@ class DistEmbedding: ...@@ -104,8 +117,7 @@ class DistEmbedding:
return emb return emb
def reset_trace(self): def reset_trace(self):
'''Reset the traced data. """Reset the traced data."""
'''
self._trace = [] self._trace = []
@property @property
......
...@@ -10,9 +10,9 @@ import dgl
from .... import backend as F
from ...dist_tensor import DistTensor
from ...graph_partition_book import EDGE_PART_POLICY, NODE_PART_POLICY
from ...nn.pytorch import DistEmbedding
from .utils import alltoall_cpu, alltoallv_cpu

EMB_STATES = "emb_states"
WORLD_SIZE = "world_size"
......
...@@ -3,40 +3,41 @@
import json
import os
import time

import numpy as np

from .. import backend as F
from ..base import DGLError, EID, ETYPE, NID, NTYPE
from ..convert import to_homogeneous
from ..data.utils import load_graphs, load_tensors, save_graphs, save_tensors
from ..partition import (
    get_peak_mem,
    metis_partition_assignment,
    partition_graph_with_halo,
)
from ..random import choice as random_choice
from ..transforms import sort_csc_by_tag, sort_csr_by_tag
from .constants import DEFAULT_ETYPE, DEFAULT_NTYPE
from .graph_partition_book import (
    _etype_str_to_tuple,
    _etype_tuple_to_str,
    RangePartitionBook,
)

RESERVED_FIELD_DTYPE = {
    "inner_node": F.uint8,  # A flag indicating whether the node is inside a partition.
    "inner_edge": F.uint8,  # A flag indicating whether the edge is inside a partition.
    NID: F.int64,
    EID: F.int64,
    NTYPE: F.int16,
    # `sort_csr_by_tag` and `sort_csc_by_tag` work on int32/64 only.
    ETYPE: F.int32,
}
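This dtype map is enforced field by field before partitions are saved (see _save_graphs below). A standalone sketch of that cast loop, with numpy standing in for the backend and illustrative field names:

# Sketch of enforcing RESERVED_FIELD_DTYPE on partition data (numpy stands
# in for F.astype; the field names are illustrative).
import numpy as np

reserved = {"inner_node": np.uint8, "inner_edge": np.uint8, "_ID": np.int64}
ndata = {"inner_node": np.array([1, 0, 1], dtype=np.int64)}
for k, dtype in reserved.items():
    if k in ndata:
        ndata[k] = ndata[k].astype(dtype)
print(ndata["inner_node"].dtype)  # uint8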
def _format_part_metadata(part_metadata, formatter):
    """Format etypes with the specified formatter."""
    for key in ["edge_map", "etypes"]:
        if key not in part_metadata:
            continue
        orig_data = part_metadata[key]
...@@ -49,32 +50,36 @@ def _format_part_metadata(part_metadata, formatter):
        part_metadata[key] = new_data
    return part_metadata


def _load_part_config(part_config):
    """Load part config and format."""
    try:
        with open(part_config) as f:
            part_metadata = _format_part_metadata(
                json.load(f), _etype_str_to_tuple
            )
    except AssertionError as e:
        raise DGLError(
            f"Failed to load partition config due to {e}. "
            "Probably caused by an outdated config. If so, please refer to "
            "https://github.com/dmlc/dgl/tree/master/tools#change-edge-"
            "type-to-canonical-edge-type-for-partition-configuration-json"
        )
    return part_metadata


def _dump_part_config(part_config, part_metadata):
    """Format and dump part config."""
    part_metadata = _format_part_metadata(part_metadata, _etype_tuple_to_str)
    with open(part_config, "w") as outfile:
        json.dump(part_metadata, outfile, sort_keys=True, indent=4)
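The load/dump pair converts etype keys between tuple and string form via the two formatters. A minimal sketch of that round-trip, assuming ":" as the separator (matching DGL's canonical "srctype:etype:dsttype" form):

# Minimal sketch of the etype round-trip performed on "etypes"/"edge_map".
def etype_tuple_to_str(t):
    return ":".join(t)

def etype_str_to_tuple(s):
    t = tuple(s.split(":"))
    assert len(t) == 3, "Invalid canonical etype: %s" % s
    return t

c_etype = ("user", "follows", "game")
assert etype_str_to_tuple(etype_tuple_to_str(c_etype)) == c_etype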
def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
    """Preprocess partitions before saving:
    1. format data types.
    2. sort csc/csr by tag.
    """
    for g in g_list:
        for k, dtype in RESERVED_FIELD_DTYPE.items():
            if k in g.ndata:
...@@ -84,25 +89,36 @@ def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
    for g in g_list:
        if (not sort_etypes) or (formats is None):
            continue
        if "csr" in formats:
            g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge")
        if "csc" in formats:
            g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge")
    save_graphs(filename, g_list, formats=formats)


def _get_inner_node_mask(graph, ntype_id):
    if NTYPE in graph.ndata:
        dtype = F.dtype(graph.ndata["inner_node"])
        return (
            graph.ndata["inner_node"]
            * F.astype(graph.ndata[NTYPE] == ntype_id, dtype)
            == 1
        )
    else:
        return graph.ndata["inner_node"] == 1


def _get_inner_edge_mask(graph, etype_id):
    if ETYPE in graph.edata:
        dtype = F.dtype(graph.edata["inner_edge"])
        return (
            graph.edata["inner_edge"]
            * F.astype(graph.edata[ETYPE] == etype_id, dtype)
            == 1
        )
    else:
        return graph.edata["inner_edge"] == 1
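The two mask helpers combine the inner flag with a type test by multiplying them, so a node (or edge) passes only if it is both inside the partition and of the requested type. A numpy sketch with toy values:

# Numpy sketch of the combined inner/type mask computed above.
import numpy as np

inner_node = np.array([1, 1, 0, 1], dtype=np.uint8)  # inside this partition?
ntype = np.array([0, 1, 1, 1])                        # per-node type IDs
ntype_id = 1
mask = inner_node * (ntype == ntype_id).astype(inner_node.dtype) == 1
print(mask)  # [False  True False  True]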
def _get_part_ranges(id_ranges):
    res = {}
...@@ -116,11 +132,14 @@ def _get_part_ranges(id_ranges):
        for i, end in enumerate(id_ranges[key]):
            id_ranges[key][i] = [start, end]
            start = end
        res[key] = np.concatenate(
            [np.array(l) for l in id_ranges[key]]
        ).reshape(-1, 2)
    return res
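A worked example of the boundary-to-range conversion above, with made-up partition boundaries: a list of end offsets becomes an array of [start, end) rows.

# Worked example of _get_part_ranges' boundary -> [start, end) conversion.
import numpy as np

id_ranges = {"_N": [4, 9, 12]}  # per-partition end offsets (toy numbers)
start = 0
for i, end in enumerate(id_ranges["_N"]):
    id_ranges["_N"][i] = [start, end]
    start = end
res = np.concatenate([np.array(l) for l in id_ranges["_N"]]).reshape(-1, 2)
print(res)  # [[0 4] [4 9] [9 12]]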
def load_partition(part_config, part_id, load_feats=True):
    """Load data of a partition from the data path.

    A partition's data includes the graph structure of the partition, a dict of node tensors,
    a dict of edge tensors and some metadata. The partition may contain the HALO nodes,
...@@ -158,60 +177,87 @@ def load_partition(part_config, part_id, load_feats=True):
        The node types
    List[(str, str, str)]
        The edge types
    """
    config_path = os.path.dirname(part_config)
    relative_to_config = lambda path: os.path.join(config_path, path)

    with open(part_config) as conf_f:
        part_metadata = json.load(conf_f)
    assert (
        "part-{}".format(part_id) in part_metadata
    ), "part-{} does not exist".format(part_id)
    part_files = part_metadata["part-{}".format(part_id)]
    assert (
        "part_graph" in part_files
    ), "the partition does not contain graph structure."
    graph = load_graphs(relative_to_config(part_files["part_graph"]))[0][0]

    assert (
        NID in graph.ndata
    ), "the partition graph should contain node mapping to global node ID"
    assert (
        EID in graph.edata
    ), "the partition graph should contain edge mapping to global edge ID"

    gpb, graph_name, ntypes, etypes = load_partition_book(part_config, part_id)
    ntypes_list = list(ntypes.keys())
    etypes_list = list(etypes.keys())
    if "DGL_DIST_DEBUG" in os.environ:
        for ntype in ntypes:
            ntype_id = ntypes[ntype]
            # graph.ndata[NID] are global homogeneous node IDs.
            nids = F.boolean_mask(
                graph.ndata[NID], _get_inner_node_mask(graph, ntype_id)
            )
            partids1 = gpb.nid2partid(nids)
            _, per_type_nids = gpb.map_to_per_ntype(nids)
            partids2 = gpb.nid2partid(per_type_nids, ntype)
            assert np.all(F.asnumpy(partids1 == part_id)), (
                "Unexpected partition IDs are found in the loaded partition "
                "while querying via global homogeneous node IDs."
            )
            assert np.all(F.asnumpy(partids2 == part_id)), (
                "Unexpected partition IDs are found in the loaded partition "
                "while querying via type-wise node IDs."
            )
        for etype in etypes:
            etype_id = etypes[etype]
            # graph.edata[EID] are global homogeneous edge IDs.
            eids = F.boolean_mask(
                graph.edata[EID], _get_inner_edge_mask(graph, etype_id)
            )
            partids1 = gpb.eid2partid(eids)
            _, per_type_eids = gpb.map_to_per_etype(eids)
            partids2 = gpb.eid2partid(per_type_eids, etype)
            assert np.all(F.asnumpy(partids1 == part_id)), (
                "Unexpected partition IDs are found in the loaded partition "
                "while querying via global homogeneous edge IDs."
            )
            assert np.all(F.asnumpy(partids2 == part_id)), (
                "Unexpected partition IDs are found in the loaded partition "
                "while querying via type-wise edge IDs."
            )

    node_feats = {}
    edge_feats = {}
    if load_feats:
        node_feats, edge_feats = load_partition_feats(part_config, part_id)

    return (
        graph,
        node_feats,
        edge_feats,
        gpb,
        graph_name,
        ntypes_list,
        etypes_list,
    )
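A hedged end-to-end usage sketch of this return tuple (it mirrors the doctest shown later in partition_graph's docstring; the graph and paths are made up):

# Partition a toy graph and load partition 0 back (requires dgl installed).
import dgl

g = dgl.rand_graph(100, 400)
dgl.distributed.partition_graph(g, "test", 2, out_path="output/")
(
    part_g,
    node_feats,
    edge_feats,
    gpb,
    graph_name,
    ntypes_list,
    etypes_list,
) = dgl.distributed.load_partition("output/test.json", 0)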
def load_partition_feats(
    part_config, part_id, load_nodes=True, load_edges=True
):
    """Load node/edge feature data from a partition.

    Parameters
    ----------
...@@ -230,45 +276,52 @@ def load_partition_feats(part_config, part_id, load_nodes=True, load_edges=True)
        Node features.
    Dict[str, Tensor] or None
        Edge features.
    """
    config_path = os.path.dirname(part_config)
    relative_to_config = lambda path: os.path.join(config_path, path)

    with open(part_config) as conf_f:
        part_metadata = json.load(conf_f)
    assert (
        "part-{}".format(part_id) in part_metadata
    ), "part-{} does not exist".format(part_id)
    part_files = part_metadata["part-{}".format(part_id)]
    assert (
        "node_feats" in part_files
    ), "the partition does not contain node features."
    assert (
        "edge_feats" in part_files
    ), "the partition does not contain edge features."
    node_feats = None
    if load_nodes:
        node_feats = load_tensors(relative_to_config(part_files["node_feats"]))
    edge_feats = None
    if load_edges:
        edge_feats = load_tensors(relative_to_config(part_files["edge_feats"]))
    # In the old format, the feature name doesn't contain the node/edge type.
    # For compatibility, let's add node/edge types to the feature names.
    if node_feats is not None:
        new_feats = {}
        for name in node_feats:
            feat = node_feats[name]
            if name.find("/") == -1:
                name = DEFAULT_NTYPE + "/" + name
            new_feats[name] = feat
        node_feats = new_feats
    if edge_feats is not None:
        new_feats = {}
        for name in edge_feats:
            feat = edge_feats[name]
            if name.find("/") == -1:
                name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name
            new_feats[name] = feat
        edge_feats = new_feats

    return node_feats, edge_feats
def load_partition_book(part_config, part_id):
    """Load a graph partition book from the partition config file.

    Parameters
    ----------
...@@ -287,23 +340,30 @@ def load_partition_book(part_config, part_id):
        The node types
    dict
        The edge types
    """
    part_metadata = _load_part_config(part_config)
    assert "num_parts" in part_metadata, "num_parts does not exist."
    assert (
        part_metadata["num_parts"] > part_id
    ), "part {} is out of range (#parts: {})".format(
        part_id, part_metadata["num_parts"]
    )
    num_parts = part_metadata["num_parts"]
    assert (
        "num_nodes" in part_metadata
    ), "cannot get the number of nodes of the global graph."
    assert (
        "num_edges" in part_metadata
    ), "cannot get the number of edges of the global graph."
    assert "node_map" in part_metadata, "cannot get the node map."
    assert "edge_map" in part_metadata, "cannot get the edge map."
    assert "graph_name" in part_metadata, "cannot get the graph name"

    # If this is a range partitioning, node_map actually stores a list, whose elements
    # indicate the boundary of range partitioning. Otherwise, node_map stores a filename
    # that contains the node map in a NumPy array.
    node_map = part_metadata["node_map"]
    edge_map = part_metadata["edge_map"]
    if isinstance(node_map, dict):
        for key in node_map:
            is_range_part = isinstance(node_map[key], list)
...@@ -318,28 +378,35 @@ def load_partition_book(part_config, part_id):
    ntypes = {DEFAULT_NTYPE: 0}
    etypes = {DEFAULT_ETYPE: 0}
    if "ntypes" in part_metadata:
        ntypes = part_metadata["ntypes"]
    if "etypes" in part_metadata:
        etypes = part_metadata["etypes"]

    if isinstance(node_map, dict):
        for key in node_map:
            assert key in ntypes, "The node type {} is invalid".format(key)
    if isinstance(edge_map, dict):
        for key in edge_map:
            assert key in etypes, "The edge type {} is invalid".format(key)

    if not is_range_part:
        raise TypeError("Only RangePartitionBook is supported currently.")

    node_map = _get_part_ranges(node_map)
    edge_map = _get_part_ranges(edge_map)
    return (
        RangePartitionBook(
            part_id, num_parts, node_map, edge_map, ntypes, etypes
        ),
        part_metadata["graph_name"],
        ntypes,
        etypes,
    )
def _get_orig_ids(g, sim_g, orig_nids, orig_eids):
    """Convert/construct the original node IDs and edge IDs.

    It handles multiple cases:
    * If the graph has been reshuffled and it's a homogeneous graph, we just return
...@@ -363,7 +430,7 @@ def _get_orig_ids(g, sim_g, orig_nids, orig_eids):
    Returns
    -------
    tensor or dict of tensors, tensor or dict of tensors
    """
    is_hetero = not g.is_homogeneous
    if is_hetero:
        # Get the type IDs
...@@ -372,14 +439,23 @@ def _get_orig_ids(g, sim_g, orig_nids, orig_eids):
        # Mapping between shuffled global IDs to original per-type IDs
        orig_nids = F.gather_row(sim_g.ndata[NID], orig_nids)
        orig_eids = F.gather_row(sim_g.edata[EID], orig_eids)
        orig_nids = {
            ntype: F.boolean_mask(
                orig_nids, orig_ntype == g.get_ntype_id(ntype)
            )
            for ntype in g.ntypes
        }
        orig_eids = {
            etype: F.boolean_mask(
                orig_eids, orig_etype == g.get_etype_id(etype)
            )
            for etype in g.canonical_etypes
        }
    return orig_nids, orig_eids


def _set_trainer_ids(g, sim_g, node_parts):
    """Set the trainer IDs for each node and edge on the input graph.

    The trainer IDs will be stored as node data and edge data in the input graph.
...@@ -391,29 +467,44 @@ def _set_trainer_ids(g, sim_g, node_parts):
        The homogeneous version of the input graph.
    node_parts : tensor
        The node partition ID for each node in `sim_g`.
    """
    if g.is_homogeneous:
        g.ndata["trainer_id"] = node_parts
        # An edge is assigned to a partition based on its destination node.
        g.edata["trainer_id"] = F.gather_row(node_parts, g.edges()[1])
    else:
        for ntype_id, ntype in enumerate(g.ntypes):
            type_idx = sim_g.ndata[NTYPE] == ntype_id
            orig_nid = F.boolean_mask(sim_g.ndata[NID], type_idx)
            trainer_id = F.zeros((len(orig_nid),), F.dtype(node_parts), F.cpu())
            F.scatter_row_inplace(
                trainer_id, orig_nid, F.boolean_mask(node_parts, type_idx)
            )
            g.nodes[ntype].data["trainer_id"] = trainer_id
        for c_etype in g.canonical_etypes:
            # An edge is assigned to a partition based on its destination node.
            _, _, dst_type = c_etype
            trainer_id = F.gather_row(
                g.nodes[dst_type].data["trainer_id"], g.edges(etype=c_etype)[1]
            )
            g.edges[c_etype].data["trainer_id"] = trainer_id
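The "an edge follows its destination node" rule above reduces to a gather of the destination's partition ID. A numpy sketch with a toy graph (gather_row becomes fancy indexing):

# Sketch of assigning each edge the partition of its destination node.
import numpy as np

node_parts = np.array([0, 0, 1, 1])  # partition of each node (toy values)
dst = np.array([2, 0, 3, 1])         # destination node of each edge
edge_parts = node_parts[dst]
print(edge_parts)                    # [1 0 1 0]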
def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method="metis",
balance_ntypes=None, balance_edges=False, return_mapping=False,
num_trainers_per_machine=1, objtype='cut', graph_formats=None): def partition_graph(
''' Partition a graph for distributed training and store the partitions on files. g,
graph_name,
num_parts,
out_path,
num_hops=1,
part_method="metis",
balance_ntypes=None,
balance_edges=False,
return_mapping=False,
num_trainers_per_machine=1,
objtype="cut",
graph_formats=None,
):
"""Partition a graph for distributed training and store the partitions on files.
The partitioning occurs in three steps: 1) run a partition algorithm (e.g., Metis) to The partitioning occurs in three steps: 1) run a partition algorithm (e.g., Metis) to
assign nodes to partitions; 2) construct partition graph structure based on assign nodes to partitions; 2) construct partition graph structure based on
...@@ -608,10 +699,12 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method= ...@@ -608,10 +699,12 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
>>> ( >>> (
... g, node_feats, edge_feats, gpb, graph_name, ntypes_list, etypes_list, ... g, node_feats, edge_feats, gpb, graph_name, ntypes_list, etypes_list,
... ) = dgl.distributed.load_partition('output/test.json', 0) ... ) = dgl.distributed.load_partition('output/test.json', 0)
''' """
# 'coo' is required for partition # 'coo' is required for partition
assert 'coo' in np.concatenate(list(g.formats().values())), \ assert "coo" in np.concatenate(
"'coo' format should be allowed for partitioning graph." list(g.formats().values())
), "'coo' format should be allowed for partitioning graph."
def get_homogeneous(g, balance_ntypes): def get_homogeneous(g, balance_ntypes):
if g.is_homogeneous: if g.is_homogeneous:
sim_g = to_homogeneous(g) sim_g = to_homogeneous(g)
...@@ -626,49 +719,65 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method= ...@@ -626,49 +719,65 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
num_ntypes = 0 num_ntypes = 0
for key in g.ntypes: for key in g.ntypes:
if key in balance_ntypes: if key in balance_ntypes:
g.nodes[key].data['bal_ntype'] = F.astype(balance_ntypes[key], g.nodes[key].data["bal_ntype"] = (
F.int32) + num_ntypes F.astype(balance_ntypes[key], F.int32) + num_ntypes
)
uniq_ntypes = F.unique(balance_ntypes[key]) uniq_ntypes = F.unique(balance_ntypes[key])
assert np.all(F.asnumpy(uniq_ntypes) == np.arange(len(uniq_ntypes))) assert np.all(
F.asnumpy(uniq_ntypes) == np.arange(len(uniq_ntypes))
)
num_ntypes += len(uniq_ntypes) num_ntypes += len(uniq_ntypes)
else: else:
g.nodes[key].data['bal_ntype'] = F.ones((g.number_of_nodes(key),), F.int32, g.nodes[key].data["bal_ntype"] = (
F.cpu()) * num_ntypes F.ones((g.number_of_nodes(key),), F.int32, F.cpu())
* num_ntypes
)
num_ntypes += 1 num_ntypes += 1
sim_g = to_homogeneous(g, ndata=['bal_ntype']) sim_g = to_homogeneous(g, ndata=["bal_ntype"])
bal_ntypes = sim_g.ndata['bal_ntype'] bal_ntypes = sim_g.ndata["bal_ntype"]
print('The graph has {} node types and balance among {} types'.format( print(
len(g.ntypes), len(F.unique(bal_ntypes)))) "The graph has {} node types and balance among {} types".format(
len(g.ntypes), len(F.unique(bal_ntypes))
)
)
# We now no longer need them. # We now no longer need them.
for key in g.ntypes: for key in g.ntypes:
del g.nodes[key].data['bal_ntype'] del g.nodes[key].data["bal_ntype"]
del sim_g.ndata['bal_ntype'] del sim_g.ndata["bal_ntype"]
else: else:
sim_g = to_homogeneous(g) sim_g = to_homogeneous(g)
bal_ntypes = sim_g.ndata[NTYPE] bal_ntypes = sim_g.ndata[NTYPE]
return sim_g, bal_ntypes return sim_g, bal_ntypes
if objtype not in ['cut', 'vol']: if objtype not in ["cut", "vol"]:
raise ValueError raise ValueError
if num_parts == 1: if num_parts == 1:
start = time.time() start = time.time()
sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes) sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)
print('Converting to homogeneous graph takes {:.3f}s, peak mem: {:.3f} GB'.format( print(
time.time() - start, get_peak_mem())) "Converting to homogeneous graph takes {:.3f}s, peak mem: {:.3f} GB".format(
time.time() - start, get_peak_mem()
)
)
assert num_trainers_per_machine >= 1 assert num_trainers_per_machine >= 1
if num_trainers_per_machine > 1: if num_trainers_per_machine > 1:
# First partition the whole graph to each trainer and save the trainer ids in # First partition the whole graph to each trainer and save the trainer ids in
# the node feature "trainer_id". # the node feature "trainer_id".
start = time.time() start = time.time()
node_parts = metis_partition_assignment( node_parts = metis_partition_assignment(
sim_g, num_parts * num_trainers_per_machine, sim_g,
num_parts * num_trainers_per_machine,
balance_ntypes=balance_ntypes, balance_ntypes=balance_ntypes,
balance_edges=balance_edges, balance_edges=balance_edges,
mode='k-way') mode="k-way",
)
_set_trainer_ids(g, sim_g, node_parts) _set_trainer_ids(g, sim_g, node_parts)
print('Assigning nodes to METIS partitions takes {:.3f}s, peak mem: {:.3f} GB'.format( print(
time.time() - start, get_peak_mem())) "Assigning nodes to METIS partitions takes {:.3f}s, peak mem: {:.3f} GB".format(
time.time() - start, get_peak_mem()
)
)
node_parts = F.zeros((sim_g.number_of_nodes(),), F.int64, F.cpu()) node_parts = F.zeros((sim_g.number_of_nodes(),), F.int64, F.cpu())
parts = {0: sim_g.clone()} parts = {0: sim_g.clone()}
...@@ -676,60 +785,86 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method= ...@@ -676,60 +785,86 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
orig_eids = parts[0].edata[EID] = F.arange(0, sim_g.number_of_edges()) orig_eids = parts[0].edata[EID] = F.arange(0, sim_g.number_of_edges())
# For one partition, we don't really shuffle nodes and edges. We just need to simulate # For one partition, we don't really shuffle nodes and edges. We just need to simulate
# it and set node data and edge data of orig_id. # it and set node data and edge data of orig_id.
parts[0].ndata['orig_id'] = orig_nids parts[0].ndata["orig_id"] = orig_nids
parts[0].edata['orig_id'] = orig_eids parts[0].edata["orig_id"] = orig_eids
if return_mapping: if return_mapping:
if g.is_homogeneous: if g.is_homogeneous:
            orig_nids = F.arange(0, sim_g.number_of_nodes())
            orig_eids = F.arange(0, sim_g.number_of_edges())
        else:
            orig_nids = {
                ntype: F.arange(0, g.number_of_nodes(ntype))
                for ntype in g.ntypes
            }
            orig_eids = {
                etype: F.arange(0, g.number_of_edges(etype))
                for etype in g.canonical_etypes
            }
        parts[0].ndata["inner_node"] = F.ones(
            (sim_g.number_of_nodes(),),
            RESERVED_FIELD_DTYPE["inner_node"],
            F.cpu(),
        )
        parts[0].edata["inner_edge"] = F.ones(
            (sim_g.number_of_edges(),),
            RESERVED_FIELD_DTYPE["inner_edge"],
            F.cpu(),
        )
    elif part_method in ("metis", "random"):
        start = time.time()
        sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)
        print(
            "Converting to homogeneous graph takes {:.3f}s, peak mem: {:.3f} GB".format(
                time.time() - start, get_peak_mem()
            )
        )
        if part_method == "metis":
            assert num_trainers_per_machine >= 1
            start = time.time()
            if num_trainers_per_machine > 1:
                # First partition the whole graph to each trainer and save the trainer ids in
                # the node feature "trainer_id".
                node_parts = metis_partition_assignment(
                    sim_g,
                    num_parts * num_trainers_per_machine,
                    balance_ntypes=balance_ntypes,
                    balance_edges=balance_edges,
                    mode="k-way",
                    objtype=objtype,
                )
                _set_trainer_ids(g, sim_g, node_parts)
                # And then coalesce the partitions of trainers on the same machine into one
                # larger partition.
                node_parts = F.floor_div(node_parts, num_trainers_per_machine)
            else:
                node_parts = metis_partition_assignment(
                    sim_g,
                    num_parts,
                    balance_ntypes=balance_ntypes,
                    balance_edges=balance_edges,
                    objtype=objtype,
                )
            print(
                "Assigning nodes to METIS partitions takes {:.3f}s, peak mem: {:.3f} GB".format(
                    time.time() - start, get_peak_mem()
                )
            )
        else:
            node_parts = random_choice(num_parts, sim_g.number_of_nodes())
        start = time.time()
        parts, orig_nids, orig_eids = partition_graph_with_halo(
            sim_g, node_parts, num_hops, reshuffle=True
        )
        print(
            "Splitting the graph into partitions takes {:.3f}s, peak mem: {:.3f} GB".format(
                time.time() - start, get_peak_mem()
            )
        )
        if return_mapping:
            orig_nids, orig_eids = _get_orig_ids(g, sim_g, orig_nids, orig_eids)
    else:
        raise Exception("Unknown partitioning method: " + part_method)
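    # Illustration of the trainer-coalescing arithmetic above (all values are
    # hypothetical, not from the source): with num_parts=2 machines and
    # num_trainers_per_machine=4, METIS produces trainer-level parts 0..7 and
    # F.floor_div folds them onto machines, e.g.
    #     node_parts      = [0, 3, 5, 7, 2, 6]
    #     node_parts // 4 = [0, 0, 1, 1, 0, 1]
    # so trainers 0-3 end up on machine 0 and trainers 4-7 on machine 1.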
    # If the input is a heterogeneous graph, get the original node types and original node IDs.
    # `part' has three types of node data at this point.
@@ -738,38 +873,56 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
    # NID: the global node IDs in the reshuffled homogeneous version of the input graph.
    if not g.is_homogeneous:
        for name in parts:
            orig_ids = parts[name].ndata["orig_id"]
            ntype = F.gather_row(sim_g.ndata[NTYPE], orig_ids)
            parts[name].ndata[NTYPE] = F.astype(
                ntype, RESERVED_FIELD_DTYPE[NTYPE]
            )
            assert np.all(
                F.asnumpy(ntype) == F.asnumpy(parts[name].ndata[NTYPE])
            )
            # Get the original edge types and original edge IDs.
            orig_ids = parts[name].edata["orig_id"]
            etype = F.gather_row(sim_g.edata[ETYPE], orig_ids)
            parts[name].edata[ETYPE] = F.astype(
                etype, RESERVED_FIELD_DTYPE[ETYPE]
            )
            assert np.all(
                F.asnumpy(etype) == F.asnumpy(parts[name].edata[ETYPE])
            )

            # Calculate the mapping from global node IDs to per-type node IDs.
            inner_ntype = F.boolean_mask(
                parts[name].ndata[NTYPE], parts[name].ndata["inner_node"] == 1
            )
            inner_nids = F.boolean_mask(
                parts[name].ndata[NID], parts[name].ndata["inner_node"] == 1
            )
            for ntype in g.ntypes:
                inner_ntype_mask = inner_ntype == g.get_ntype_id(ntype)
                typed_nids = F.boolean_mask(inner_nids, inner_ntype_mask)
                # Inner node IDs are in a contiguous ID range.
                expected_range = np.arange(
                    int(F.as_scalar(typed_nids[0])),
                    int(F.as_scalar(typed_nids[-1])) + 1,
                )
                assert np.all(F.asnumpy(typed_nids) == expected_range)
            # Calculate the mapping from global edge IDs to per-type edge IDs.
            inner_etype = F.boolean_mask(
                parts[name].edata[ETYPE], parts[name].edata["inner_edge"] == 1
            )
            inner_eids = F.boolean_mask(
                parts[name].edata[EID], parts[name].edata["inner_edge"] == 1
            )
            for etype in g.canonical_etypes:
                inner_etype_mask = inner_etype == g.get_etype_id(etype)
                typed_eids = np.sort(
                    F.asnumpy(F.boolean_mask(inner_eids, inner_etype_mask))
                )
                assert np.all(
                    typed_eids
                    == np.arange(int(typed_eids[0]), int(typed_eids[-1]) + 1)
                )

    os.makedirs(out_path, mode=0o775, exist_ok=True)
    tot_num_inner_edges = 0
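The assertions above rely on the reshuffling done by partition_graph_with_halo: within each partition, the inner nodes and edges of every type occupy one contiguous range of global IDs, which is what lets the node/edge maps below store only [start, end) pairs. A minimal self-contained sketch of the same check (toy IDs, only numpy assumed):

import numpy as np

# Hypothetical inner node IDs of one type inside one partition.
typed_nids = np.array([120, 121, 122, 123, 124])
expected_range = np.arange(int(typed_nids[0]), int(typed_nids[-1]) + 1)
# Contiguous, so the whole set is representable as the range [120, 125).
assert np.all(typed_nids == expected_range)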
@@ -786,10 +939,18 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
            node_map_val[ntype] = []
            for i in parts:
                inner_node_mask = _get_inner_node_mask(parts[i], ntype_id)
                val.append(
                    F.as_scalar(F.sum(F.astype(inner_node_mask, F.int64), 0))
                )
                inner_nids = F.boolean_mask(
                    parts[i].ndata[NID], inner_node_mask
                )
                node_map_val[ntype].append(
                    [
                        int(F.as_scalar(inner_nids[0])),
                        int(F.as_scalar(inner_nids[-1])) + 1,
                    ]
                )
            val = np.cumsum(val).tolist()
            assert val[-1] == g.number_of_nodes(ntype)
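        # For illustration only (hypothetical numbers): with two partitions and
        # a single node type "user", the loop above leaves
        #     node_map_val == {"user": [[0, 1000], [1000, 2500]]}
        # i.e. one [start, end) range of global node IDs per partition.
        # edge_map_val below is built the same way per canonical edge type.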
        for etype in g.canonical_etypes:
@@ -798,10 +959,17 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
            edge_map_val[etype] = []
            for i in parts:
                inner_edge_mask = _get_inner_edge_mask(parts[i], etype_id)
                val.append(
                    F.as_scalar(F.sum(F.astype(inner_edge_mask, F.int64), 0))
                )
                inner_eids = np.sort(
                    F.asnumpy(
                        F.boolean_mask(parts[i].edata[EID], inner_edge_mask)
                    )
                )
                edge_map_val[etype].append(
                    [int(inner_eids[0]), int(inner_eids[-1]) + 1]
                )
            val = np.cumsum(val).tolist()
            assert val[-1] == g.number_of_edges(etype)
    else:
@@ -811,14 +979,22 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
            ntype_id = g.get_ntype_id(ntype)
            inner_node_mask = _get_inner_node_mask(parts[0], ntype_id)
            inner_nids = F.boolean_mask(parts[0].ndata[NID], inner_node_mask)
            node_map_val[ntype] = [
                [
                    int(F.as_scalar(inner_nids[0])),
                    int(F.as_scalar(inner_nids[-1])) + 1,
                ]
            ]
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            inner_edge_mask = _get_inner_edge_mask(parts[0], etype_id)
            inner_eids = F.boolean_mask(parts[0].edata[EID], inner_edge_mask)
            edge_map_val[etype] = [
                [
                    int(F.as_scalar(inner_eids[0])),
                    int(F.as_scalar(inner_eids[-1])) + 1,
                ]
            ]

    # Double check that the node IDs in the global ID space are sorted.
    for ntype in node_map_val:
@@ -829,18 +1005,20 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
        assert np.all(val[:-1] <= val[1:])

    start = time.time()
    ntypes = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
    etypes = {etype: g.get_etype_id(etype) for etype in g.canonical_etypes}
    part_metadata = {
        "graph_name": graph_name,
        "num_nodes": g.number_of_nodes(),
        "num_edges": g.number_of_edges(),
        "part_method": part_method,
        "num_parts": num_parts,
        "halo_hops": num_hops,
        "node_map": node_map_val,
        "edge_map": edge_map_val,
        "ntypes": ntypes,
        "etypes": etypes,
    }
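    # A sketch of what this metadata becomes once dumped to
    # <out_path>/<graph_name>.json below (all values hypothetical):
    #     {"graph_name": "test", "num_nodes": 2500, "num_edges": 10000,
    #      "part_method": "metis", "num_parts": 2, "halo_hops": 1,
    #      "node_map": {"user": [[0, 1000], [1000, 2500]]}, ...,
    #      "part-0": {"node_feats": "part0/node_feat.dgl", ...}}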
    for part_id in range(num_parts):
        part = parts[part_id]
@@ -852,107 +1030,150 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
                ntype_id = g.get_ntype_id(ntype)
                # To get the edges in the input graph, we should use original node IDs.
                # Both orig_id and NID store the per-node-type IDs.
                ndata_name = "orig_id"
                inner_node_mask = _get_inner_node_mask(part, ntype_id)
                # These are global node IDs.
                local_nodes = F.boolean_mask(
                    part.ndata[ndata_name], inner_node_mask
                )
                if len(g.ntypes) > 1:
                    # If the input is a heterogeneous graph.
                    local_nodes = F.gather_row(sim_g.ndata[NID], local_nodes)
                    print(
                        "part {} has {} nodes of type {} and {} are inside the partition".format(
                            part_id,
                            F.as_scalar(
                                F.sum(part.ndata[NTYPE] == ntype_id, 0)
                            ),
                            ntype,
                            len(local_nodes),
                        )
                    )
                else:
                    print(
                        "part {} has {} nodes and {} are inside the partition".format(
                            part_id, part.number_of_nodes(), len(local_nodes)
                        )
                    )
                for name in g.nodes[ntype].data:
                    if name in [NID, "inner_node"]:
                        continue
                    node_feats[ntype + "/" + name] = F.gather_row(
                        g.nodes[ntype].data[name], local_nodes
                    )
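            # At this point the node_feats keys look like "<ntype>/<name>",
            # e.g. "user/feat" (hypothetical names), and each value holds only
            # the feature rows owned by this partition.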
            for etype in g.canonical_etypes:
                etype_id = g.get_etype_id(etype)
                edata_name = "orig_id"
                inner_edge_mask = _get_inner_edge_mask(part, etype_id)
                # These are global edge IDs.
                local_edges = F.boolean_mask(
                    part.edata[edata_name], inner_edge_mask
                )
                if not g.is_homogeneous:
                    local_edges = F.gather_row(sim_g.edata[EID], local_edges)
                    print(
                        "part {} has {} edges of type {} and {} are inside the partition".format(
                            part_id,
                            F.as_scalar(
                                F.sum(part.edata[ETYPE] == etype_id, 0)
                            ),
                            etype,
                            len(local_edges),
                        )
                    )
                else:
                    print(
                        "part {} has {} edges and {} are inside the partition".format(
                            part_id, part.number_of_edges(), len(local_edges)
                        )
                    )
                tot_num_inner_edges += len(local_edges)
                for name in g.edges[etype].data:
                    if name in [EID, "inner_edge"]:
                        continue
                    edge_feats[
                        _etype_tuple_to_str(etype) + "/" + name
                    ] = F.gather_row(g.edges[etype].data[name], local_edges)
        else:
            for ntype in g.ntypes:
                if len(g.ntypes) > 1:
                    ndata_name = "orig_id"
                    ntype_id = g.get_ntype_id(ntype)
                    inner_node_mask = _get_inner_node_mask(part, ntype_id)
                    # These are global node IDs.
                    local_nodes = F.boolean_mask(
                        part.ndata[ndata_name], inner_node_mask
                    )
                    local_nodes = F.gather_row(sim_g.ndata[NID], local_nodes)
                else:
                    local_nodes = sim_g.ndata[NID]
                for name in g.nodes[ntype].data:
                    if name in [NID, "inner_node"]:
                        continue
                    node_feats[ntype + "/" + name] = F.gather_row(
                        g.nodes[ntype].data[name], local_nodes
                    )
            for etype in g.canonical_etypes:
                if not g.is_homogeneous:
                    edata_name = "orig_id"
                    etype_id = g.get_etype_id(etype)
                    inner_edge_mask = _get_inner_edge_mask(part, etype_id)
                    # These are global edge IDs.
                    local_edges = F.boolean_mask(
                        part.edata[edata_name], inner_edge_mask
                    )
                    local_edges = F.gather_row(sim_g.edata[EID], local_edges)
                else:
                    local_edges = sim_g.edata[EID]
                for name in g.edges[etype].data:
                    if name in [EID, "inner_edge"]:
                        continue
                    edge_feats[
                        _etype_tuple_to_str(etype) + "/" + name
                    ] = F.gather_row(g.edges[etype].data[name], local_edges)

        # Delete `orig_id` from ndata/edata.
        del part.ndata["orig_id"]
        del part.edata["orig_id"]

        part_dir = os.path.join(out_path, "part" + str(part_id))
        node_feat_file = os.path.join(part_dir, "node_feat.dgl")
        edge_feat_file = os.path.join(part_dir, "edge_feat.dgl")
        part_graph_file = os.path.join(part_dir, "graph.dgl")
        part_metadata["part-{}".format(part_id)] = {
            "node_feats": os.path.relpath(node_feat_file, out_path),
            "edge_feats": os.path.relpath(edge_feat_file, out_path),
            "part_graph": os.path.relpath(part_graph_file, out_path),
        }
        os.makedirs(part_dir, mode=0o775, exist_ok=True)
        save_tensors(node_feat_file, node_feats)
        save_tensors(edge_feat_file, edge_feats)

        sort_etypes = len(g.etypes) > 1
        _save_graphs(
            part_graph_file,
            [part],
            formats=graph_formats,
            sort_etypes=sort_etypes,
        )
    print(
        "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format(
            time.time() - start, get_peak_mem()
        )
    )

    _dump_part_config(f"{out_path}/{graph_name}.json", part_metadata)

    num_cuts = sim_g.number_of_edges() - tot_num_inner_edges
    if num_parts == 1:
        num_cuts = 0
    print(
        "There are {} edges in the graph and {} edge cuts for {} partitions.".format(
            g.number_of_edges(), num_cuts, num_parts
        )
    )

    if return_mapping:
        return orig_nids, orig_eids
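For orientation, a minimal driver for the function above (a sketch with made-up graph sizes and names; the 7-tuple returned by load_partition follows the DGL distributed API, so treat it as an assumption to verify against your release):

import dgl
import torch

g = dgl.rand_graph(1000, 5000)            # toy homogeneous graph
g.ndata["feat"] = torch.randn(1000, 16)   # hypothetical node feature

# Write two METIS partitions under ./test_parts and keep the ID mappings.
orig_nids, orig_eids = dgl.distributed.partition_graph(
    g, "test", num_parts=2, out_path="test_parts",
    part_method="metis", return_mapping=True,
)

# The JSON written by _dump_part_config above is the entry point for loading.
part_g, node_feats, edge_feats, gpb, graph_name, ntypes, etypes = (
    dgl.distributed.load_partition("test_parts/test.json", 0)
)
print(graph_name, part_g.number_of_nodes())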
"""Define utility functions for shared memory.""" """Define utility functions for shared memory."""
from .. import backend as F from .. import backend as F, ndarray as nd
from .. import ndarray as nd
from .._ffi.ndarray import empty_shared_mem from .._ffi.ndarray import empty_shared_mem
DTYPE_DICT = F.data_type_dict DTYPE_DICT = F.data_type_dict
@@ -5,12 +5,14 @@ This kvstore is used when running in the standalone mode
from .. import backend as F


class KVClient(object):
    """The fake KVStore client.

    This is to mimic the distributed KVStore client. It's used for DistGraph
    in standalone mode.
    """

    def __init__(self):
        self._data = {}
        self._all_possible_part_policy = {}
@@ -30,25 +32,27 @@ class KVClient(object):
        return 1

    def barrier(self):
        """barrier"""

    def register_push_handler(self, name, func):
        """register push handler"""
        self._push_handlers[name] = func

    def register_pull_handler(self, name, func):
        """register pull handler"""
        self._pull_handlers[name] = func

    def add_data(self, name, tensor, part_policy):
        """add data to the client"""
        self._data[name] = tensor
        self._gdata_name_list.add(name)
        if part_policy.policy_str not in self._all_possible_part_policy:
            self._all_possible_part_policy[part_policy.policy_str] = part_policy

    def init_data(
        self, name, shape, dtype, part_policy, init_func, is_gdata=True
    ):
        """add new data to the client"""
        self._data[name] = init_func(shape, dtype)
        if part_policy.policy_str not in self._all_possible_part_policy:
            self._all_possible_part_policy[part_policy.policy_str] = part_policy
@@ -56,38 +60,38 @@ class KVClient(object):
        self._gdata_name_list.add(name)

    def delete_data(self, name):
        """delete the data"""
        del self._data[name]
        self._gdata_name_list.remove(name)

    def data_name_list(self):
        """get the names of all data"""
        return list(self._data.keys())

    def gdata_name_list(self):
        """get the names of graph data"""
        return list(self._gdata_name_list)

    def get_data_meta(self, name):
        """get the metadata of data"""
        return F.dtype(self._data[name]), F.shape(self._data[name]), None

    def push(self, name, id_tensor, data_tensor):
        """push data to kvstore"""
        if name in self._push_handlers:
            self._push_handlers[name](self._data, name, id_tensor, data_tensor)
        else:
            F.scatter_row_inplace(self._data[name], id_tensor, data_tensor)

    def pull(self, name, id_tensor):
        """pull data from kvstore"""
        if name in self._pull_handlers:
            return self._pull_handlers[name](self._data, name, id_tensor)
        else:
            return F.gather_row(self._data[name], id_tensor)

    def map_shared_data(self, partition_book):
        """Mapping shared-memory tensor from server to client."""

    def count_nonzero(self, name):
        """Count nonzero value by pull request from KVServers.
@@ -116,8 +120,7 @@ class KVClient(object):
        return self._data

    def union(self, operand1_name, operand2_name, output_name):
        """Compute the union of two mask arrays in the KVStore."""
        self._data[output_name][:] = (
            self._data[operand1_name] | self._data[operand2_name]
        )
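A small end-to-end sketch of the fake client (assuming the PyTorch backend; _FakePolicy is a hypothetical stub that carries the single attribute KVClient reads):

import torch

class _FakePolicy:
    # Made-up policy label; KVClient only uses it as a dict key.
    policy_str = "node"

kv = KVClient()
kv.add_data("feat", torch.zeros(10, 4), _FakePolicy())
kv.push("feat", torch.tensor([0, 3]), torch.ones(2, 4))  # writes rows 0 and 3
print(kv.pull("feat", torch.tensor([0, 1])))  # row 0 is ones, row 1 stays zeros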