"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6921393ae27f7ab9f3f25f9e772ec42cfdf82f63"
Unverified Commit 68377251 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable DistGraphServer to load graphbolt partitions (#7042)

parent fe78093f
...@@ -8,7 +8,7 @@ from collections.abc import MutableMapping ...@@ -8,7 +8,7 @@ from collections.abc import MutableMapping
import numpy as np import numpy as np
from .. import backend as F, heterograph_index from .. import backend as F, graphbolt as gb, heterograph_index
from .._ffi.ndarray import empty_shared_mem from .._ffi.ndarray import empty_shared_mem
from ..base import ALL, DGLError, EID, ETYPE, is_all, NID from ..base import ALL, DGLError, EID, ETYPE, is_all, NID
from ..convert import graph as dgl_graph, heterograph as dgl_heterograph from ..convert import graph as dgl_graph, heterograph as dgl_heterograph
...@@ -88,7 +88,9 @@ class InitGraphResponse(rpc.Response): ...@@ -88,7 +88,9 @@ class InitGraphResponse(rpc.Response):
self._graph_name = state self._graph_name = state
def _copy_graph_to_shared_mem(g, graph_name, graph_format): def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
if use_graphbolt:
return g.copy_to_shared_memory(graph_name)
new_g = g.shared_memory(graph_name, formats=graph_format) new_g = g.shared_memory(graph_name, formats=graph_format)
# We should share the node/edge data to the client explicitly instead of putting them # We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated. # in the KVStore because some of the node/edge data may be duplicated.
...@@ -298,6 +300,30 @@ class EdgeDataView(MutableMapping): ...@@ -298,6 +300,30 @@ class EdgeDataView(MutableMapping):
return repr(reprs) return repr(reprs)
def _format_partition(graph, graph_format):
"""Format the partition to the specified format."""
if isinstance(graph, gb.FusedCSCSamplingGraph):
return graph
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in graph.ndata:
graph.ndata[k] = F.astype(graph.ndata[k], dtype)
if k in graph.edata:
graph.edata[k] = F.astype(graph.edata[k], dtype)
# Create the graph formats specified the users.
print(
"Start to create specified graph formats which may take "
"non-trivial time."
)
graph = graph.formats(graph_format)
graph.create_formats_()
print(f"Finished creating specified graph formats: {graph_format}")
return graph
class DistGraphServer(KVServer): class DistGraphServer(KVServer):
"""The DistGraph server. """The DistGraph server.
...@@ -330,6 +356,8 @@ class DistGraphServer(KVServer): ...@@ -330,6 +356,8 @@ class DistGraphServer(KVServer):
Disable shared memory. Disable shared memory.
graph_format : str or list of str graph_format : str or list of str
The graph formats. The graph formats.
use_graphbolt : bool
Whether to load GraphBolt partition. Default: False.
""" """
def __init__( def __init__(
...@@ -341,6 +369,7 @@ class DistGraphServer(KVServer): ...@@ -341,6 +369,7 @@ class DistGraphServer(KVServer):
part_config, part_config,
disable_shared_mem=False, disable_shared_mem=False,
graph_format=("csc", "coo"), graph_format=("csc", "coo"),
use_graphbolt=False,
): ):
super(DistGraphServer, self).__init__( super(DistGraphServer, self).__init__(
server_id=server_id, server_id=server_id,
...@@ -350,6 +379,7 @@ class DistGraphServer(KVServer): ...@@ -350,6 +379,7 @@ class DistGraphServer(KVServer):
) )
self.ip_config = ip_config self.ip_config = ip_config
self.num_servers = num_servers self.num_servers = num_servers
self.use_graphbolt = use_graphbolt
# Load graph partition data. # Load graph partition data.
if self.is_backup_server(): if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards. # The backup server doesn't load the graph partition. It'll initialized afterwards.
...@@ -367,32 +397,17 @@ class DistGraphServer(KVServer): ...@@ -367,32 +397,17 @@ class DistGraphServer(KVServer):
graph_name, graph_name,
ntypes, ntypes,
etypes, etypes,
) = load_partition(part_config, self.part_id, load_feats=False) ) = load_partition(
print("load " + graph_name) part_config,
# formatting dtype self.part_id,
# TODO(Rui) Formatting forcely is not a perfect solution. load_feats=False,
# We'd better store all dtypes when mapping to shared memory use_graphbolt=use_graphbolt,
# and map back with original dtypes.
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in self.client_g.ndata:
self.client_g.ndata[k] = F.astype(
self.client_g.ndata[k], dtype
)
if k in self.client_g.edata:
self.client_g.edata[k] = F.astype(
self.client_g.edata[k], dtype
)
# Create the graph formats specified the users.
print(
"Start to create specified graph formats which may take "
"non-trivial time."
) )
self.client_g = self.client_g.formats(graph_format) print("load " + graph_name)
self.client_g.create_formats_() self.client_g = _format_partition(self.client_g, graph_format)
print("Finished creating specified graph formats.")
if not disable_shared_mem: if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem( self.client_g = _copy_graph_to_shared_mem(
self.client_g, graph_name, graph_format self.client_g, graph_name, graph_format, use_graphbolt
) )
if not disable_shared_mem: if not disable_shared_mem:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment