Unverified Commit 2ae40af5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: rupei/datagen cleanups (#1394)

parent a6887fa0
......@@ -19,6 +19,30 @@ from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT
from data_generator.sampler import get_cdf
def _verify_tree(G: nx.DiGraph) -> None:
invalid_nodes = [(node, d) for node, d in G.in_degree() if d > 1]
if invalid_nodes:
print("ERROR: The following nodes have multiple parents (in-degree > 1):")
for node, in_degree in invalid_nodes:
parents = list(G.predecessors(node))
print(f" Node {node}: in-degree={in_degree}, parents={parents}")
raise ValueError(
"Graph is not a valid tree: nodes with multiple parents detected"
)
def _mark_visited(G: nx.DiGraph) -> None:
# visits to leaf nodes (non-core branches) are considered as ended
for node in G.nodes():
if "to_leaf" not in G.nodes[node]:
G.nodes[node]["to_leaf"] = 0
if G.nodes[node]["visited"] <= 1:
continue
for child in G.successors(node):
if G.nodes[child]["visited"] == 1:
G.nodes[node]["to_leaf"] += 1
def _merge_chains(G: nx.DiGraph) -> nx.DiGraph:
"""
Make the graph radix-like (meaning all unary paths are contracted).
......@@ -73,7 +97,9 @@ def _merge_chains(G: nx.DiGraph) -> nx.DiGraph:
chain_len += 1
succ = list(G.successors(end_node))
G.add_edge(node_pred, end_node, weight=weight)
G.add_edge(
node_pred, end_node, weight=weight
) # may overwrite the edge (should be harmless)
G.nodes[end_node]["length"] = chain_len
G.remove_nodes_from(nodes_rm)
......
......@@ -56,7 +56,6 @@ def texts_to_hashes(
parent_hash = 0
hashes: List[int] = []
print(blocks)
for block in blocks:
combined = (parent_hash, hash(tuple(block)))
global_hash = hash(combined)
......
......@@ -21,9 +21,11 @@ import networkx as nx
import numpy as np
import pandas as pd
from data_generator.graph_utils import (
_mark_visited,
_merge_chains,
_precompute_transition_cdfs,
_remove_leaves,
_verify_tree,
)
from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT
from data_generator.sampler import EmpiricalSampler, sample_from_cdf
......@@ -87,7 +89,7 @@ class Synthesizer:
assert (
isinstance(self.prefix_len_multiplier, float)
and self.prefix_len_multiplier > 0
), "context_len_multiplier must be a positive float"
), "prefix_len_multiplier must be a positive float"
assert (
isinstance(self.prompt_len_multiplier, float)
and self.prompt_len_multiplier > 0
......@@ -134,28 +136,9 @@ class Synthesizer:
self.G.nodes[SUPER_ROOT]["visited"] = num_paths
self.max_hash_id = max_hash_id
invalid_nodes = [(node, d) for node, d in self.G.in_degree() if d > 1]
if invalid_nodes:
print("ERROR: The following nodes have multiple parents (in-degree > 1):")
for node, in_degree in invalid_nodes:
parents = list(self.G.predecessors(node))
print(f" Node {node}: in-degree={in_degree}, parents={parents}")
raise ValueError(
"Graph is not a valid tree: nodes with multiple parents detected"
)
# visits to leaf nodes (non-core branches) are considered as ended
for node in self.G.nodes():
if "to_leaf" not in self.G.nodes[node]:
self.G.nodes[node]["to_leaf"] = 0
if self.G.nodes[node]["visited"] <= 1:
continue
for child in self.G.successors(node):
if self.G.nodes[child]["visited"] == 1:
self.G.nodes[node]["to_leaf"] += 1
# make graph radix-like
self.G = _merge_chains(self.G)
_verify_tree(self.G)
_mark_visited(self.G)
self.G = _merge_chains(self.G) # make graph radix-like
self.G, leaves_lens = _remove_leaves(self.G)
# Apply prompt_len_multiplier to leaves_lens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment