Unverified Commit 98b9e0fa authored by kylasa's avatar kylasa Committed by GitHub
Browse files

[Dist] Create <graph_name>_stats.txt file if it does not exist before ParMETIS execution (#4791)

* check if stats file exists, if not create one before parmetis run

* correct the typo error and correctly use constants.GRAPH_NAME
parent 53117c51
...@@ -287,6 +287,36 @@ def gen_parmetis_input_args(params, schema_map): ...@@ -287,6 +287,36 @@ def gen_parmetis_input_args(params, schema_map):
schema_map[constants.STR_NUM_NODES_PER_CHUNK], schema_map[constants.STR_NUM_NODES_PER_CHUNK],
) )
# Check if <graph-name>_stats.txt exists, if not create one using metadata.
# Here stats file will be created in the current directory.
# No. of constraints, third column in the stats file is computed as follows:
# num_constraints = no. of node types + train_mask + test_mask + val_mask
# Here, (train/test/val) masks will be set to 1 if these masks exist for
# all the node types in the graph, otherwise these flags will be set to 0
assert constants.GRAPH_NAMEe in schema_map, "Graph name is not present in the json file"
graph_name = schema_map[constants.GRAPH_NAME]
if not os.path.isfile(f'{graph_name}_stats.txt'):
num_nodes = np.sum(np.concatenate(schema_map[constants.STR_NUM_NODES_PER_CHUNK]))
num_edges = np.sum(np.concatenate(schema_map[constants.STR_NUM_EDGES_PER_CHUNK]))
num_ntypes = len(schema_map[constants.STR_NODE_TYPE])
train_mask = test_mask = val_mask = 0
node_feats = schema_map[constants.STR_NODE_DATA]
for ntype, ntype_data in node_feats.items():
if "train_mask" in ntype_data:
train_mask += 1
if "test_mask" in ntype_data:
test_mask += 1
if "val_mask" in ntype_data:
val_mask += 1
train_mask = train_mask // num_ntypes
test_mask = test_mask // num_ntypes
val_mask = val_mask // num_ntypes
num_constraints = num_nyptes + train_mask + test_mask + val_mask
with open(f'{graph_name}_stats.txt', 'w') as sf:
sf.write(f'{num_nodes} {num_edges} {num_constraints}')
node_files = [] node_files = []
outdir = Path(params.output_dir) outdir = Path(params.output_dir)
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment