Unverified commit 32dc1af6, authored by Rhett Ying, committed by GitHub
Browse files

[Dist] generate partition meta for ParMETIS pipeline (#5020)

* [Dist] generate partition meta for ParMETIS
parent f4f9abc8
...@@ -12,6 +12,7 @@ from chunk_graph import chunk_graph ...@@ -12,6 +12,7 @@ from chunk_graph import chunk_graph
from dgl.data.utils import load_graphs, load_tensors from dgl.data.utils import load_graphs, load_tensors
from create_chunked_dataset import create_chunked_dataset from create_chunked_dataset import create_chunked_dataset
from partition_algo.base import load_partition_meta
""" """
TODO: skipping this test case since the dependency, mpirun, is TODO: skipping this test case since the dependency, mpirun, is
...@@ -162,6 +163,13 @@ def test_parmetis_postprocessing(): ...@@ -162,6 +163,13 @@ def test_parmetis_postprocessing():
assert np.min(part_ids) == 0 assert np.min(part_ids) == 0
assert np.max(part_ids) == 1 assert np.max(part_ids) == 1
# check partition meta file
part_meta_file = os.path.join(results_dir, "partition_meta.json")
assert os.path.isfile(part_meta_file)
part_meta = load_partition_meta(part_meta_file)
assert part_meta.num_parts == 2
assert part_meta.algo_name == "metis"
""" """
TODO: skipping this test case since it depends on the dependency, mpi, TODO: skipping this test case since it depends on the dependency, mpi,
......
...@@ -11,6 +11,7 @@ import pyarrow.csv as csv ...@@ -11,6 +11,7 @@ import pyarrow.csv as csv
import constants import constants
from utils import get_idranges, get_node_types, read_json from utils import get_idranges, get_node_types, read_json
from partition_algo.base import PartitionMeta, dump_partition_meta
def post_process(params): def post_process(params):
...@@ -39,6 +40,7 @@ def post_process(params): ...@@ -39,6 +40,7 @@ def post_process(params):
) )
global_nids = metis_df["f0"].to_numpy() global_nids = metis_df["f0"].to_numpy()
partition_ids = metis_df["f1"].to_numpy() partition_ids = metis_df["f1"].to_numpy()
num_parts = np.unique(partition_ids).size
sort_idx = np.argsort(global_nids) sort_idx = np.argsort(global_nids)
global_nids = global_nids[sort_idx] global_nids = global_nids[sort_idx]
...@@ -66,6 +68,13 @@ def post_process(params): ...@@ -66,6 +68,13 @@ def post_process(params):
options, options,
) )
logging.info(f"Generated {out_file}") logging.info(f"Generated {out_file}")
# generate partition meta file.
part_meta = PartitionMeta(
version="1.0.0", num_parts=num_parts, algo_name="metis"
)
dump_partition_meta(part_meta, os.path.join(outdir, "partition_meta.json"))
logging.info("Done processing parmetis output") logging.info("Done processing parmetis output")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment