Unverified commit b1309217, authored by Rhett Ying, committed by GitHub

[Dist] Reduce peak memory in DistDGL (#4687)

* [Dist] Reduce peak memory in DistDGL: avoid validation, release memory once loaded

* remove orig_id from ndata/edata for partition_graph()

* delete orig_id from ndata/edata in dist part pipeline

* reduce dtype size and format before saving graphs

* fix lint

* ETYPE must be int32/64 for CSRSortByTag

* fix test failure

* refine
parent 0a33500c
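In short, the commit cuts peak memory three ways: partition validation in `load_partition` now runs only when `DGL_DIST_DEBUG` is set; the server loads node features and edge features in two stages, freeing each batch once it has been copied to shared memory; and reserved graph fields are narrowed to smaller dtypes before partitions are saved. Below is a minimal sketch of the staged load-and-free pattern, not the actual server code; `serve_features` and `copy_to_shared_memory` are hypothetical stand-ins for the server's hand-off step.

```python
import gc

from dgl.distributed import load_partition_feats

def serve_features(part_config, part_id, copy_to_shared_memory):
    # Stage 1: node features only; edge features stay on disk.
    node_feats, _ = load_partition_feats(part_config, part_id,
                                         load_nodes=True, load_edges=False)
    for name, tensor in node_feats.items():
        copy_to_shared_memory(name, tensor)
    del node_feats   # drop the last reference ...
    gc.collect()     # ... and reclaim the memory before the next stage

    # Stage 2: edge features only.
    _, edge_feats = load_partition_feats(part_config, part_id,
                                         load_nodes=False, load_edges=True)
    for name, tensor in edge_feats.items():
        copy_to_shared_memory(name, tensor)
    del edge_feats
    gc.collect()
```

With this staging, the process holds at most one of the two feature sets in local memory at any time instead of both at once.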
@@ -4,6 +4,7 @@ from collections.abc import MutableMapping
 from collections import namedtuple
 import os
+import gc
 import numpy as np
 from ..heterograph import DGLHeteroGraph
@@ -12,7 +13,7 @@ from ..convert import graph as dgl_graph
 from ..transforms import compact_graphs, sort_csr_by_tag, sort_csc_by_tag
 from .. import heterograph_index
 from .. import backend as F
-from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all
+from ..base import NID, EID, ETYPE, ALL, is_all
 from .kvstore import KVServer, get_kvstore
 from .._ffi.ndarray import empty_shared_mem
 from ..ndarray import exist_shared_mem_array
@@ -31,6 +32,7 @@ from .graph_services import find_edges as dist_find_edges
 from .graph_services import out_degrees as dist_out_degrees
 from .graph_services import in_degrees as dist_in_degrees
 from .dist_tensor import DistTensor
+from .partition import RESERVED_FIELD_DTYPE

 INIT_GRAPH = 800001
@@ -84,13 +86,6 @@ def _copy_graph_to_shared_mem(g, graph_name, graph_format):
     new_g.edata[ETYPE] = _to_shared_mem(g.edata[ETYPE], _get_edata_path(graph_name, ETYPE))
     return new_g

-FIELD_DICT = {'inner_node': F.int32,  # A flag indicates whether the node is inside a partition.
-              'inner_edge': F.int32,  # A flag indicates whether the edge is inside a partition.
-              NID: F.int64,
-              EID: F.int64,
-              NTYPE: F.int32,
-              ETYPE: F.int32}

 def _get_shared_mem_ndata(g, graph_name, name):
     ''' Get shared-memory node data from DistGraph server.
@@ -98,7 +93,7 @@ def _get_shared_mem_ndata(g, graph_name, name):
     with shared memory.
     '''
     shape = (g.number_of_nodes(),)
-    dtype = FIELD_DICT[name]
+    dtype = RESERVED_FIELD_DTYPE[name]
     dtype = DTYPE_DICT[dtype]
     data = empty_shared_mem(_get_ndata_path(graph_name, name), False, shape, dtype)
     dlpack = data.to_dlpack()
@@ -111,7 +106,7 @@ def _get_shared_mem_edata(g, graph_name, name):
     with shared memory.
     '''
     shape = (g.number_of_edges(),)
-    dtype = FIELD_DICT[name]
+    dtype = RESERVED_FIELD_DTYPE[name]
    dtype = DTYPE_DICT[dtype]
     data = empty_shared_mem(_get_edata_path(graph_name, name), False, shape, dtype)
     dlpack = data.to_dlpack()
@@ -340,7 +335,7 @@ class DistGraphServer(KVServer):
             # TODO(Rui) Formatting forcely is not a perfect solution.
             # We'd better store all dtypes when mapping to shared memory
             # and map back with original dtypes.
-            for k, dtype in FIELD_DICT.items():
+            for k, dtype in RESERVED_FIELD_DTYPE.items():
                 if k in self.client_g.ndata:
                     self.client_g.ndata[k] = F.astype(
                         self.client_g.ndata[k], dtype)
@@ -372,7 +367,8 @@ class DistGraphServer(KVServer):
                 self.add_part_policy(PartitionPolicy(edge_name.policy_str, self.gpb))
             if not self.is_backup_server():
-                node_feats, edge_feats = load_partition_feats(part_config, self.part_id)
+                node_feats, _ = load_partition_feats(part_config, self.part_id,
+                                                     load_nodes=True, load_edges=False)
                 for name in node_feats:
                     # The feature name has the following format: node_type + "/" + feature_name to avoid
                     # feature name collision for different node types.
@@ -381,6 +377,11 @@ class DistGraphServer(KVServer):
                     self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                                    data_tensor=node_feats[name])
                     self.orig_data.add(str(data_name))
+                # Let's free the memory once node features are copied to shared memory.
+                del node_feats
+                gc.collect()
+                _, edge_feats = load_partition_feats(part_config, self.part_id,
+                                                     load_nodes=False, load_edges=True)
                 for name in edge_feats:
                     # The feature name has the following format: edge_type + "/" + feature_name to avoid
                     # feature name collision for different edge types.
@@ -389,6 +390,9 @@ class DistGraphServer(KVServer):
                     self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                                    data_tensor=edge_feats[name])
                     self.orig_data.add(str(data_name))
+                # Let's free the memory once edge features are copied to shared memory.
+                del edge_feats
+                gc.collect()

     def start(self):
         """ Start graph store server.
...
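The dtype narrowing that follows is the other lever: `inner_node`/`inner_edge` shrink from int32 to uint8 and `NTYPE` from int32 to int16, while `ETYPE` has to stay int32 because `sort_csr_by_tag`/`sort_csc_by_tag` accept int32/64 only. Rough, illustrative arithmetic under an assumed partition size:

```python
# Assumed sizes for illustration only: 100M nodes, 1B edges in one partition.
num_nodes, num_edges = 100_000_000, 1_000_000_000
saved_bytes = (
    num_nodes * (4 - 1)    # inner_node: int32 -> uint8
    + num_edges * (4 - 1)  # inner_edge: int32 -> uint8
    + num_nodes * (4 - 2)  # NTYPE: int32 -> int16
)                          # ETYPE stays int32 for CSR/CSC sort-by-tag
print(f"~{saved_bytes / 2**30:.1f} GiB saved")  # ~3.3 GiB
```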
@@ -13,6 +13,27 @@ from ..data.utils import load_graphs, save_graphs, load_tensors, save_tensors
 from ..partition import metis_partition_assignment, partition_graph_with_halo, get_peak_mem
 from .graph_partition_book import BasicPartitionBook, RangePartitionBook

+RESERVED_FIELD_DTYPE = {
+    'inner_node': F.uint8,  # A flag indicates whether the node is inside a partition.
+    'inner_edge': F.uint8,  # A flag indicates whether the edge is inside a partition.
+    NID: F.int64,
+    EID: F.int64,
+    NTYPE: F.int16,
+    # `sort_csr_by_tag` and `sort_csc_by_tag` works on int32/64 only.
+    ETYPE: F.int32
+}
+
+def _save_graphs(filename, g_list):
+    '''Format data types in graphs before saving.
+    '''
+    for g in g_list:
+        for k, dtype in RESERVED_FIELD_DTYPE.items():
+            if k in g.ndata:
+                g.ndata[k] = F.astype(g.ndata[k], dtype)
+            if k in g.edata:
+                g.edata[k] = F.astype(g.edata[k], dtype)
+    save_graphs(filename, g_list)

 def _get_inner_node_mask(graph, ntype_id):
     if NTYPE in graph.ndata:
         dtype = F.dtype(graph.ndata['inner_node'])
@@ -96,7 +117,9 @@ def load_partition(part_config, part_id, load_feats=True):
     assert EID in graph.edata, "the partition graph should contain edge mapping to global edge ID"

     gpb, graph_name, ntypes, etypes = load_partition_book(part_config, part_id, graph)
-    ntypes_list, etypes_list = [], []
+    ntypes_list = list(ntypes.keys())
+    etypes_list = list(etypes.keys())
+    if 'DGL_DIST_DEBUG' in os.environ:
         for ntype in ntypes:
             ntype_id = ntypes[ntype]
             # graph.ndata[NID] are global homogeneous node IDs.
@@ -104,9 +127,12 @@ def load_partition(part_config, part_id, load_feats=True):
             partids1 = gpb.nid2partid(nids)
             _, per_type_nids = gpb.map_to_per_ntype(nids)
             partids2 = gpb.nid2partid(per_type_nids, ntype)
-            assert np.all(F.asnumpy(partids1 == part_id)), 'load a wrong partition'
-            assert np.all(F.asnumpy(partids2 == part_id)), 'load a wrong partition'
-            ntypes_list.append(ntype)
+            assert np.all(F.asnumpy(partids1 == part_id)), \
+                'Unexpected partition IDs are found in the loaded partition ' \
+                'while querying via global homogeneous node IDs.'
+            assert np.all(F.asnumpy(partids2 == part_id)), \
+                'Unexpected partition IDs are found in the loaded partition ' \
+                'while querying via type-wise node IDs.'
         for etype in etypes:
             etype_id = etypes[etype]
             # graph.edata[EID] are global homogeneous edge IDs.
@@ -114,9 +140,12 @@ def load_partition(part_config, part_id, load_feats=True):
             partids1 = gpb.eid2partid(eids)
             _, per_type_eids = gpb.map_to_per_etype(eids)
             partids2 = gpb.eid2partid(per_type_eids, etype)
-            assert np.all(F.asnumpy(partids1 == part_id)), 'load a wrong partition'
-            assert np.all(F.asnumpy(partids2 == part_id)), 'load a wrong partition'
-            etypes_list.append(etype)
+            assert np.all(F.asnumpy(partids1 == part_id)), \
+                'Unexpected partition IDs are found in the loaded partition ' \
+                'while querying via global homogeneous edge IDs.'
+            assert np.all(F.asnumpy(partids2 == part_id)), \
+                'Unexpected partition IDs are found in the loaded partition ' \
+                'while querying via type-wise edge IDs.'

     node_feats = {}
     edge_feats = {}
@@ -125,7 +154,7 @@ def load_partition(part_config, part_id, load_feats=True):
     return graph, node_feats, edge_feats, gpb, graph_name, ntypes_list, etypes_list

-def load_partition_feats(part_config, part_id):
+def load_partition_feats(part_config, part_id, load_nodes=True, load_edges=True):
     '''Load node/edge feature data from a partition.

     Parameters
@@ -134,12 +163,16 @@ def load_partition_feats(part_config, part_id):
         The path of the partition config file.
     part_id : int
         The partition ID.
+    load_nodes : bool, optional
+        Whether to load node features. If ``False``, ``None`` is returned.
+    load_edges : bool, optional
+        Whether to load edge features. If ``False``, ``None`` is returned.

     Returns
     -------
-    Dict[str, Tensor]
+    Dict[str, Tensor] or None
         Node features.
-    Dict[str, Tensor]
+    Dict[str, Tensor] or None
         Edge features.
     '''
     config_path = os.path.dirname(part_config)
@@ -151,24 +184,30 @@ def load_partition_feats(part_config, part_id):
     part_files = part_metadata['part-{}'.format(part_id)]
     assert 'node_feats' in part_files, "the partition does not contain node features."
     assert 'edge_feats' in part_files, "the partition does not contain edge feature."
-    node_feats = load_tensors(relative_to_config(part_files['node_feats']))
-    edge_feats = load_tensors(relative_to_config(part_files['edge_feats']))
+    node_feats = None
+    if load_nodes:
+        node_feats = load_tensors(relative_to_config(part_files['node_feats']))
+    edge_feats = None
+    if load_edges:
+        edge_feats = load_tensors(relative_to_config(part_files['edge_feats']))
     # In the old format, the feature name doesn't contain node/edge type.
     # For compatibility, let's add node/edge types to the feature names.
-    node_feats1 = {}
-    edge_feats1 = {}
-    for name in node_feats:
-        feat = node_feats[name]
-        if name.find('/') == -1:
-            name = '_N/' + name
-        node_feats1[name] = feat
-    for name in edge_feats:
-        feat = edge_feats[name]
-        if name.find('/') == -1:
-            name = '_E/' + name
-        edge_feats1[name] = feat
-    node_feats = node_feats1
-    edge_feats = edge_feats1
+    if node_feats is not None:
+        new_feats = {}
+        for name in node_feats:
+            feat = node_feats[name]
+            if name.find('/') == -1:
+                name = '_N/' + name
+            new_feats[name] = feat
+        node_feats = new_feats
+    if edge_feats is not None:
+        new_feats = {}
+        for name in edge_feats:
+            feat = edge_feats[name]
+            if name.find('/') == -1:
+                name = '_E/' + name
+            new_feats[name] = feat
+        edge_feats = new_feats

     return node_feats, edge_feats
@@ -448,13 +487,11 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
     under name `dgl.EID`. For a heterogeneous graph, the DGLGraph also contains a node
     data `dgl.NTYPE` for node type and an edge data `dgl.ETYPE` for the edge type.

-    The partition graph contains additional node data ("inner_node" and "orig_id") and
+    The partition graph contains additional node data ("inner_node") and
     edge data ("inner_edge"):

     * "inner_node" indicates whether a node belongs to a partition.
     * "inner_edge" indicates whether an edge belongs to a partition.
-    * "orig_id" exists when reshuffle=True. It indicates the original node IDs in the original
-      graph before reshuffling.

     Node and edge features are splitted and stored together with each graph partition.
     All node/edge features in a partition are stored in a file with DGL format. The node/edge
@@ -608,8 +645,10 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
             parts[0].edata['orig_id'] = orig_eids
         if return_mapping:
             orig_nids, orig_eids = _get_orig_ids(g, sim_g, False, orig_nids, orig_eids)
-        parts[0].ndata['inner_node'] = F.ones((sim_g.number_of_nodes(),), F.int8, F.cpu())
-        parts[0].edata['inner_edge'] = F.ones((sim_g.number_of_edges(),), F.int8, F.cpu())
+        parts[0].ndata['inner_node'] = F.ones((sim_g.number_of_nodes(),),
+                                              RESERVED_FIELD_DTYPE['inner_node'], F.cpu())
+        parts[0].edata['inner_edge'] = F.ones((sim_g.number_of_edges(),),
+                                              RESERVED_FIELD_DTYPE['inner_edge'], F.cpu())
     elif part_method in ('metis', 'random'):
         start = time.time()
         sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)
@@ -660,12 +699,12 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
         for name in parts:
             orig_ids = parts[name].ndata['orig_id']
             ntype = F.gather_row(sim_g.ndata[NTYPE], orig_ids)
-            parts[name].ndata[NTYPE] = F.astype(ntype, F.int32)
+            parts[name].ndata[NTYPE] = F.astype(ntype, RESERVED_FIELD_DTYPE[NTYPE])
             assert np.all(F.asnumpy(ntype) == F.asnumpy(parts[name].ndata[NTYPE]))

             # Get the original edge types and original edge IDs.
             orig_ids = parts[name].edata['orig_id']
             etype = F.gather_row(sim_g.edata[ETYPE], orig_ids)
-            parts[name].edata[ETYPE] = F.astype(etype, F.int32)
+            parts[name].edata[ETYPE] = F.astype(etype, RESERVED_FIELD_DTYPE[ETYPE])
             assert np.all(F.asnumpy(etype) == F.asnumpy(parts[name].edata[ETYPE]))

             # Calculate the global node IDs to per-node IDs mapping.
@@ -873,10 +912,10 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
                                                               local_edges)
                 else:
                     edge_feats[etype + '/' + name] = g.edges[etype].data[name]
-        # Some adjustment for heterogeneous graphs.
-        if not g.is_homogeneous:
-            part.ndata['orig_id'] = F.gather_row(sim_g.ndata[NID], part.ndata['orig_id'])
-            part.edata['orig_id'] = F.gather_row(sim_g.edata[EID], part.edata['orig_id'])
+        # delete `orig_id` from ndata/edata
+        if reshuffle:
+            del part.ndata['orig_id']
+            del part.edata['orig_id']

         part_dir = os.path.join(out_path, "part" + str(part_id))
         node_feat_file = os.path.join(part_dir, "node_feat.dgl")
@@ -890,7 +929,7 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
         save_tensors(node_feat_file, node_feats)
         save_tensors(edge_feat_file, edge_feats)
-        save_graphs(part_graph_file, [part])
+        _save_graphs(part_graph_file, [part])

     print('Save partitions: {:.3f} seconds, peak memory: {:.3f} GB'.format(
         time.time() - start, get_peak_mem()))
...
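Taken together, the partition module now exposes two memory-friendly entry points. A short usage sketch (the paths and the toy graph are placeholders):

```python
import dgl
from dgl.distributed import partition_graph, load_partition_feats

g = dgl.rand_graph(1000, 5000)  # any DGLGraph

# 'orig_id' is no longer stored in ndata/edata; callers that need the
# pre-reshuffle IDs must request the mapping explicitly.
orig_nids, orig_eids = partition_graph(
    g, 'demo', num_parts=2, out_path='/tmp/demo',
    num_hops=1, part_method='metis', reshuffle=True, return_mapping=True)

# Features can be loaded selectively to keep peak memory down.
node_feats, _ = load_partition_feats('/tmp/demo/demo.json', 0,
                                     load_nodes=True, load_edges=False)
```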
@@ -56,7 +56,7 @@ def run_server(
     print("start server", server_id)
     # verify dtype of underlying graph
     cg = g.client_g
-    for k, dtype in dgl.distributed.dist_graph.FIELD_DICT.items():
+    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
         if k in cg.ndata:
             assert (
                 F.dtype(cg.ndata[k]) == dtype
...
@@ -41,7 +41,8 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
     return sampled_graph

-def start_sample_client_shuffle(rank, tmpdir, disable_shared_mem, g, num_servers, group_id=0):
+def start_sample_client_shuffle(rank, tmpdir, disable_shared_mem, g, num_servers, group_id,
+                                orig_nid, orig_eid):
     os.environ['DGL_GROUP_ID'] = str(group_id)
     gpb = None
     if disable_shared_mem:
@@ -50,13 +51,6 @@ def start_sample_client_shuffle(rank, tmpdir, disable_shared_mem, g, num_servers
     dist_graph = DistGraph("test_sampling", gpb=gpb)
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

-    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
-    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
-    for i in range(num_servers):
-        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
-        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
-        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
     src, dst = sampled_graph.edges()
     src = orig_nid[src]
     dst = orig_nid[dst]
@@ -232,8 +226,8 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
     g.readonly()
     num_parts = num_server

-    partition_graph(g, 'test_get_degrees', num_parts, tmpdir,
-                    num_hops=1, part_method='metis', reshuffle=True)
+    orig_nid, _ = partition_graph(g, 'test_get_degrees', num_parts, tmpdir,
+                                  num_hops=1, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -243,11 +237,6 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', i)
-        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
-
     nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
     in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(0, tmpdir, num_server > 1, nids)
@@ -291,8 +280,8 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
     num_parts = num_server
     num_hops = 1

-    partition_graph(g, 'test_sampling', num_parts, tmpdir,
-                    num_hops=num_hops, part_method='metis', reshuffle=True)
+    orig_nids, orig_eids = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                                           num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -308,7 +297,9 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
     num_clients = 1
     for client_id in range(num_clients):
         for group_id in range(num_groups):
-            p = ctx.Process(target=start_sample_client_shuffle, args=(client_id, tmpdir, num_server > 1, g, num_server, group_id))
+            p = ctx.Process(target=start_sample_client_shuffle,
+                            args=(client_id, tmpdir, num_server > 1, g, num_server,
+                                  group_id, orig_nids, orig_eids))
             p.start()
             time.sleep(1)  # avoid race condition when instantiating DistGraph
             pclient_list.append(p)
@@ -384,8 +375,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
     num_parts = num_server
     num_hops = 1

-    partition_graph(g, 'test_sampling', num_parts, tmpdir,
-                    num_hops=num_hops, part_method='metis', reshuffle=True)
+    orig_nid_map, orig_eid_map = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                                                 num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -401,21 +392,6 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
     for p in pserver_list:
         p.join()

-    orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
-    orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
-        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
-        for ntype_id, ntype in enumerate(g.ntypes):
-            idx = ntype_ids == ntype_id
-            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
-                                  F.boolean_mask(part.ndata['orig_id'], idx))
-        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
-        for etype_id, etype in enumerate(g.etypes):
-            idx = etype_ids == etype_id
-            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
-                                  F.boolean_mask(part.edata['orig_id'], idx))
-
     for src_type, etype, dst_type in block.canonical_etypes:
         src, dst = block.edges(etype=etype)
         # These are global Ids after shuffling.
@@ -478,8 +454,8 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=False):
     num_parts = num_server
     num_hops = 1

-    partition_graph(g, 'test_sampling', num_parts, tmpdir,
-                    num_hops=num_hops, part_method='metis', reshuffle=True)
+    orig_nid_map, orig_eid_map = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                                                 num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -502,21 +478,6 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=False):
     src, dst = block.edges(etype=('n2', 'r23', 'n3'))
     assert len(src) == 18

-    orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
-    orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
-        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
-        for ntype_id, ntype in enumerate(g.ntypes):
-            idx = ntype_ids == ntype_id
-            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
-                                  F.boolean_mask(part.ndata['orig_id'], idx))
-        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
-        for etype_id, etype in enumerate(g.etypes):
-            idx = etype_ids == etype_id
-            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
-                                  F.boolean_mask(part.edata['orig_id'], idx))
-
     for src_type, etype, dst_type in block.canonical_etypes:
         src, dst = block.edges(etype=etype)
         # These are global Ids after shuffling.
@@ -666,7 +627,7 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
     num_parts = num_server
     num_hops = 1

-    orig_nids, _ = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+    orig_nid_map, orig_eid_map = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                    num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
@@ -678,7 +639,7 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    deg = get_degrees(g, orig_nids['game'], 'game')
+    deg = get_degrees(g, orig_nid_map['game'], 'game')
     nids = F.nonzero_1d(deg > 0)
     block, gpb = start_bipartite_sample_client(0, tmpdir, num_server > 1,
                                                nodes={'game': nids, 'user': [0]})
@@ -686,24 +647,6 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
     for p in pserver_list:
         p.join()

-    orig_nid_map = {ntype: F.zeros(
-        (g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
-    orig_eid_map = {etype: F.zeros(
-        (g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(
-            tmpdir / 'test_sampling.json', i)
-        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
-        for ntype_id, ntype in enumerate(g.ntypes):
-            idx = ntype_ids == ntype_id
-            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
-                                  F.boolean_mask(part.ndata['orig_id'], idx))
-        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
-        for etype_id, etype in enumerate(g.etypes):
-            idx = etype_ids == etype_id
-            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
-                                  F.boolean_mask(part.edata['orig_id'], idx))
-
     for src_type, etype, dst_type in block.canonical_etypes:
         src, dst = block.edges(etype=etype)
         # These are global Ids after shuffling.
@@ -767,7 +710,7 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
     num_parts = num_server
     num_hops = 1

-    orig_nids, _ = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+    orig_nid_map, orig_eid_map = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                    num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
@@ -780,7 +723,7 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
         pserver_list.append(p)

     fanout = 3
-    deg = get_degrees(g, orig_nids['game'], 'game')
+    deg = get_degrees(g, orig_nid_map['game'], 'game')
     nids = F.nonzero_1d(deg > 0)
     block, gpb = start_bipartite_etype_sample_client(0, tmpdir, num_server > 1, fanout,
                                                      nodes={'game': nids, 'user': [0]})
@@ -788,24 +731,6 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
     for p in pserver_list:
         p.join()

-    orig_nid_map = {ntype: F.zeros(
-        (g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
-    orig_eid_map = {etype: F.zeros(
-        (g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(
-            tmpdir / 'test_sampling.json', i)
-        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
-        for ntype_id, ntype in enumerate(g.ntypes):
-            idx = ntype_ids == ntype_id
-            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
-                                  F.boolean_mask(part.ndata['orig_id'], idx))
-        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
-        for etype_id, etype in enumerate(g.etypes):
-            idx = etype_ids == etype_id
-            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
-                                  F.boolean_mask(part.edata['orig_id'], idx))
-
     for src_type, etype, dst_type in block.canonical_etypes:
         src, dst = block.edges(etype=etype)
         # These are global Ids after shuffling.
@@ -943,8 +868,8 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
     g.readonly()
     num_parts = num_server

-    partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
-                    num_hops=1, part_method='metis', reshuffle=True)
+    orig_nid, orig_eid = partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
+                                         num_hops=1, part_method='metis', reshuffle=True, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -959,14 +884,6 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
     for p in pserver_list:
         p.join()

-    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
-    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', i)
-        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
-        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
-
     src, dst = sampled_graph.edges()
     src = orig_nid[src]
     dst = orig_nid[dst]
...
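All of the test changes above follow the same recipe: instead of rebuilding the original-ID arrays by reloading every partition and reading the now-removed `orig_id` field, the tests capture the mappings that `partition_graph(..., return_mapping=True)` returns and index into them. For the homogeneous case this reduces to something like the sketch below, reusing the names (`sampled_graph`, `orig_nid`, `g`, `F`, `np`) from the tests above:

```python
# Map the sampled graph's shuffled IDs back to pre-shuffle IDs and check
# that every sampled edge exists in the original graph.
src, dst = sampled_graph.edges()
src, dst = orig_nid[src], orig_nid[dst]
assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
```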
@@ -10,6 +10,15 @@ from dgl import function as fn
 import backend as F
 import unittest
 import tempfile
+from utils import reset_envs
+from dgl.distributed.partition import RESERVED_FIELD_DTYPE
+
+def _verify_partition_data_types(part_g):
+    for k, dtype in RESERVED_FIELD_DTYPE.items():
+        if k in part_g.ndata:
+            assert part_g.ndata[k].dtype == dtype
+        if k in part_g.edata:
+            assert part_g.edata[k].dtype == dtype

 def _get_inner_node_mask(graph, ntype_id):
     if dgl.NTYPE in graph.ndata:
@@ -78,21 +87,11 @@ def verify_hetero_graph(g, parts):
     nids = {ntype: [] for ntype in g.ntypes}
     eids = {etype: [] for etype in g.etypes}
     for part in parts:
-        src, dst, eid = part.edges(form='all')
-        orig_src = F.gather_row(part.ndata['orig_id'], src)
-        orig_dst = F.gather_row(part.ndata['orig_id'], dst)
-        orig_eid = F.gather_row(part.edata['orig_id'], eid)
+        _, _, eid = part.edges(form='all')
         etype_arr = F.gather_row(part.edata[dgl.ETYPE], eid)
         eid_type = F.gather_row(part.edata[dgl.EID], eid)
         for etype in g.etypes:
             etype_id = g.get_etype_id(etype)
-            src1 = F.boolean_mask(orig_src, etype_arr == etype_id)
-            dst1 = F.boolean_mask(orig_dst, etype_arr == etype_id)
-            eid1 = F.boolean_mask(orig_eid, etype_arr == etype_id)
-            exist = g.has_edges_between(src1, dst1, etype=etype)
-            assert np.all(F.asnumpy(exist))
-            eid2 = g.edge_ids(src1, dst1, etype=etype)
-            assert np.all(F.asnumpy(eid1 == eid2))
             eids[etype].append(F.boolean_mask(eid_type, etype_arr == etype_id))
             # Make sure edge Ids fall into a range.
             inner_edge_mask = _get_inner_edge_mask(part, etype_id)
@@ -119,7 +118,7 @@ def verify_hetero_graph(g, parts):
         assert len(uniq_ids) == g.number_of_edges(etype)
     # TODO(zhengda) this doesn't check 'part_id'

-def verify_graph_feats(g, gpb, part, node_feats, edge_feats):
+def verify_graph_feats(g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids):
     for ntype in g.ntypes:
         ntype_id = g.get_ntype_id(ntype)
         inner_node_mask = _get_inner_node_mask(part, ntype_id)
@@ -129,7 +128,7 @@ def verify_graph_feats(g, gpb, part, node_feats, edge_feats):
         assert np.all(F.asnumpy(ntype_ids) == ntype_id)
         assert np.all(F.asnumpy(partid) == gpb.partid)

-        orig_id = F.boolean_mask(part.ndata['orig_id'], inner_node_mask)
+        orig_id = orig_nids[ntype][inner_type_nids]
         local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

         for name in g.nodes[ntype].data:
@@ -148,7 +147,7 @@ def verify_graph_feats(g, gpb, part, node_feats, edge_feats):
         assert np.all(F.asnumpy(etype_ids) == etype_id)
         assert np.all(F.asnumpy(partid) == gpb.partid)

-        orig_id = F.boolean_mask(part.edata['orig_id'], inner_edge_mask)
+        orig_id = orig_eids[etype][inner_type_eids]
         local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

         for name in g.edges[etype].data:
@@ -180,6 +179,7 @@ def check_hetero_partition(hg, part_method, num_parts=4, num_trainers_per_machine
     for i in range(num_parts):
         part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
             '/tmp/partition/test.json', i, load_feats=load_feats)
+        _verify_partition_data_types(part_g)
         if not load_feats:
             assert not node_feats
             assert not edge_feats
@@ -225,7 +225,7 @@ def check_hetero_partition(hg, part_method, num_parts=4, num_trainers_per_machine
         assert len(orig_eids1) == len(orig_eids2)
         assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
         parts.append(part_g)
-        verify_graph_feats(hg, gpb, part_g, node_feats, edge_feats)
+        verify_graph_feats(hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids)

         shuffled_labels.append(node_feats['n1/labels'])
         shuffled_elabels.append(edge_feats['r1/labels'])
@@ -257,6 +257,7 @@ def check_partition(g, part_method, reshuffle, num_parts=4, num_trainers_per_machine
     for i in range(num_parts):
         part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
             '/tmp/partition/test.json', i, load_feats=load_feats)
+        _verify_partition_data_types(part_g)
         if not load_feats:
             assert not node_feats
             assert not edge_feats
@@ -323,11 +324,12 @@ def check_partition(g, part_method, reshuffle, num_parts=4, num_trainers_per_machine
         assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))

         if reshuffle:
-            part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], part_g.ndata['orig_id'])
-            part_g.edata['feats'] = F.gather_row(g.edata['feats'], part_g.edata['orig_id'])
-            # when we read node data from the original global graph, we should use orig_id.
-            local_nodes = F.boolean_mask(part_g.ndata['orig_id'], part_g.ndata['inner_node'])
-            local_edges = F.boolean_mask(part_g.edata['orig_id'], part_g.edata['inner_edge'])
+            local_orig_nids = orig_nids[part_g.ndata[dgl.NID]]
+            local_orig_eids = orig_eids[part_g.edata[dgl.EID]]
+            part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], local_orig_nids)
+            part_g.edata['feats'] = F.gather_row(g.edata['feats'], local_orig_eids)
+            local_nodes = orig_nids[local_nodes]
+            local_edges = orig_eids[local_edges]
         else:
             part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], part_g.ndata[dgl.NID])
             part_g.edata['feats'] = F.gather_row(g.edata['feats'], part_g.edata[dgl.NID])
@@ -403,6 +405,7 @@ def check_hetero_partition_single_etype(num_trainers):
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_partition():
+    os.environ['DGL_DIST_DEBUG'] = '1'
     g = create_random_graph(1000)
     check_partition(g, 'metis', False)
     check_partition(g, 'metis', True)
@@ -411,10 +414,12 @@ def test_partition():
     check_partition(g, 'random', False)
     check_partition(g, 'random', True)
     check_partition(g, 'metis', True, 4, 8, load_feats=False)
+    reset_envs()

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 def test_hetero_partition():
+    os.environ['DGL_DIST_DEBUG'] = '1'
     check_hetero_partition_single_etype(1)
     check_hetero_partition_single_etype(4)
     hg = create_random_hetero()
@@ -423,6 +428,7 @@ def test_hetero_partition():
     check_hetero_partition(hg, 'metis', 4, 8)
     check_hetero_partition(hg, 'random')
     check_hetero_partition(hg, 'metis', 4, 8, load_feats=False)
+    reset_envs()

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_BasicPartitionBook():
...
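Since `load_partition` now validates partition IDs only when `DGL_DIST_DEBUG` is present in the environment, the partition tests opt in explicitly and clean up afterwards with `reset_envs`, which the next hunk teaches to clear the new variable. The pattern, roughly:

```python
import os

os.environ['DGL_DIST_DEBUG'] = '1'   # opt in to the per-type partition ID checks
part_g, node_feats, edge_feats, gpb, name, ntypes, etypes = \
    load_partition('/tmp/partition/test.json', 0)
reset_envs()                         # clears DGL_DIST_DEBUG along with other DGL_* vars
```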
@@ -52,6 +52,7 @@ def reset_envs():
         "DGL_DIST_MODE",
         "DGL_NUM_CLIENT",
         "DGL_DIST_MAX_TRY_TIMES",
+        "DGL_DIST_DEBUG",
     ]:
         if key in os.environ:
             os.environ.pop(key)
...
@@ -11,6 +11,15 @@ from create_chunked_dataset import create_chunked_dataset
 import dgl
 from dgl.data.utils import load_graphs, load_tensors
+from dgl.distributed.partition import RESERVED_FIELD_DTYPE
+
+def _verify_partition_data_types(part_g):
+    for k, dtype in RESERVED_FIELD_DTYPE.items():
+        if k in part_g.ndata:
+            assert part_g.ndata[k].dtype == dtype
+        if k in part_g.edata:
+            assert part_g.edata[k].dtype == dtype

 @pytest.mark.parametrize("num_chunks", [1, 8])
@@ -137,6 +146,7 @@ def test_part_pipeline(num_chunks, num_parts):
     cmd += f" --partitions-dir {partition_dir}"
     cmd += f" --out-dir {out_dir}"
     cmd += f" --ip-config {ip_config}"
+    cmd += " --ssh-port 22"
     cmd += " --process-group-timeout 60"
     cmd += " --save-orig-nids"
     cmd += " --save-orig-eids"
@@ -223,6 +233,7 @@ def test_part_pipeline(num_chunks, num_parts):
         g_list, data_dict = load_graphs(fname)
         part_g = g_list[0]
         assert isinstance(part_g, dgl.DGLGraph)
+        _verify_partition_data_types(part_g)

         # node_feat.dgl
         fname = os.path.join(sub_dir, "node_feat.dgl")
...
@@ -14,6 +14,7 @@ from pyarrow import csv
 import constants
 from utils import get_idranges, memory_snapshot, read_json
+from dgl.distributed.partition import RESERVED_FIELD_DTYPE

 def create_dgl_object(schema, part_id, node_data, edge_data, edgeid_offset,
@@ -232,7 +233,7 @@ def create_dgl_object(schema, part_id, node_data, edge_data, edgeid_offset,
     2. Once the map is created, use this map to map all the node-ids in the part_local_src_id
        and part_local_dst_id list to their appropriate `new` node-ids (post-reshuffle order).
     3. Since only the node's order is changed, we will have to re-order nodes related information when
-       creating dgl object: this includes orig_id, dgl.NTYPE, dgl.NID and inner_node.
+       creating dgl object: this includes dgl.NTYPE, dgl.NID and inner_node.
     4. Edge's order is not changed. At this point in the execution path edges are still ordered by their etype-ids.
     5. Create the dgl object appropriately and return the dgl object.
@@ -272,9 +273,9 @@ def create_dgl_object(schema, part_id, node_data, edge_data, edgeid_offset,
     part_graph = dgl.graph(data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids))
     part_graph.edata[dgl.EID] = th.arange(
         edgeid_offset, edgeid_offset + part_graph.number_of_edges(), dtype=th.int64)
-    part_graph.edata['orig_id'] = th.as_tensor(global_edge_id)
-    part_graph.edata[dgl.ETYPE] = th.as_tensor(etype_ids)
-    part_graph.edata['inner_edge'] = th.ones(part_graph.number_of_edges(), dtype=th.bool)
+    part_graph.edata[dgl.ETYPE] = th.as_tensor(etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE])
+    part_graph.edata['inner_edge'] = th.ones(part_graph.number_of_edges(),
+                                             dtype=RESERVED_FIELD_DTYPE['inner_edge'])

     #compute per_type_ids and ntype for all the nodes in the graph.
@@ -285,10 +286,10 @@ def create_dgl_object(schema, part_id, node_data, edge_data, edgeid_offset,
     ntype, per_type_ids = id_map(part_global_ids)

     #continue with the graph creation
-    part_graph.ndata['orig_id'] = th.as_tensor(per_type_ids)
-    part_graph.ndata[dgl.NTYPE] = th.as_tensor(ntype)
+    part_graph.ndata[dgl.NTYPE] = th.as_tensor(ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE])
     part_graph.ndata[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes])
-    part_graph.ndata['inner_node'] = inner_nodes[reshuffle_nodes]
+    part_graph.ndata['inner_node'] = th.as_tensor(inner_nodes[reshuffle_nodes],
+                                                  dtype=RESERVED_FIELD_DTYPE['inner_node'])

     orig_nids = None
     orig_eids = None
...
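The distributed partitioning pipeline applies the same dtype policy: reserved fields are created directly with their narrowed dtypes, and (in the next hunk) `write_graph_dgl` routes saving through `_save_graphs` so any remaining fields are cast on the way out. A tiny check of the cast, assuming the PyTorch backend, where DGL's backend dtype constants are torch dtypes:

```python
import torch as th
import dgl
from dgl.distributed.partition import RESERVED_FIELD_DTYPE

etype_ids = [0, 0, 1, 2]
t = th.as_tensor(etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE])
assert t.dtype == th.int32  # ETYPE stays int32/64 for CSR/CSC sort-by-tag
```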
@@ -390,7 +390,7 @@ def write_graph_dgl(graph_file, graph_obj):
     graph_file : string
         File name in which graph object is serialized
     """
-    dgl.save_graphs(graph_file, [graph_obj])
+    dgl.distributed.partition._save_graphs(graph_file, [graph_obj])

 def write_dgl_objects(graph_obj, node_features, edge_features,
                       output_dir, part_id, orig_nids, orig_eids):
...