Unverified commit 2ae40af5, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: rupei/datagen cleanups (#1394)

parent a6887fa0
...@@ -19,6 +19,30 @@ from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT ...@@ -19,6 +19,30 @@ from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT
from data_generator.sampler import get_cdf from data_generator.sampler import get_cdf
def _verify_tree(G: nx.DiGraph) -> None:
    """Check that no node of ``G`` has more than one incoming edge.

    Only in-degrees are inspected; nothing else about the graph's shape is
    checked here. When offending nodes exist, each one is reported on stdout
    together with its in-degree and its list of predecessors before the
    exception is raised.

    Args:
        G: Directed graph to check.

    Raises:
        ValueError: If at least one node has in-degree greater than 1.
    """
    multi_parent = [(n, degree) for n, degree in G.in_degree() if degree > 1]
    if not multi_parent:
        return

    print("ERROR: The following nodes have multiple parents (in-degree > 1):")
    for n, degree in multi_parent:
        parent_list = list(G.predecessors(n))
        print(f" Node {n}: in-degree={degree}, parents={parent_list}")
    raise ValueError(
        "Graph is not a valid tree: nodes with multiple parents detected"
    )
def _mark_visited(G: nx.DiGraph) -> None:
    """Update the ``"to_leaf"`` counter on the nodes of ``G`` in place.

    Every node that lacks a ``"to_leaf"`` attribute gets it initialised to 0.
    Then, for each node whose ``"visited"`` attribute is greater than 1,
    ``"to_leaf"`` is incremented once per successor whose ``"visited"``
    attribute equals exactly 1.

    Args:
        G: Directed graph whose nodes carry a ``"visited"`` attribute; the
            node attribute dictionaries are mutated.
    """
    # visits to leaf nodes (non-core branches) are considered as ended
    for node in G.nodes():
        attrs = G.nodes[node]
        attrs.setdefault("to_leaf", 0)
        if attrs["visited"] > 1:
            for child in G.successors(node):
                if G.nodes[child]["visited"] == 1:
                    attrs["to_leaf"] += 1
def _merge_chains(G: nx.DiGraph) -> nx.DiGraph: def _merge_chains(G: nx.DiGraph) -> nx.DiGraph:
""" """
Make the graph radix-like (meaning all unary paths are contracted). Make the graph radix-like (meaning all unary paths are contracted).
...@@ -73,7 +97,9 @@ def _merge_chains(G: nx.DiGraph) -> nx.DiGraph: ...@@ -73,7 +97,9 @@ def _merge_chains(G: nx.DiGraph) -> nx.DiGraph:
chain_len += 1 chain_len += 1
succ = list(G.successors(end_node)) succ = list(G.successors(end_node))
G.add_edge(node_pred, end_node, weight=weight) G.add_edge(
node_pred, end_node, weight=weight
) # may overwrite the edge (should be harmless)
G.nodes[end_node]["length"] = chain_len G.nodes[end_node]["length"] = chain_len
G.remove_nodes_from(nodes_rm) G.remove_nodes_from(nodes_rm)
......
...@@ -56,7 +56,6 @@ def texts_to_hashes( ...@@ -56,7 +56,6 @@ def texts_to_hashes(
parent_hash = 0 parent_hash = 0
hashes: List[int] = [] hashes: List[int] = []
print(blocks)
for block in blocks: for block in blocks:
combined = (parent_hash, hash(tuple(block))) combined = (parent_hash, hash(tuple(block)))
global_hash = hash(combined) global_hash = hash(combined)
......
...@@ -21,9 +21,11 @@ import networkx as nx ...@@ -21,9 +21,11 @@ import networkx as nx
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from data_generator.graph_utils import ( from data_generator.graph_utils import (
_mark_visited,
_merge_chains, _merge_chains,
_precompute_transition_cdfs, _precompute_transition_cdfs,
_remove_leaves, _remove_leaves,
_verify_tree,
) )
from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT
from data_generator.sampler import EmpiricalSampler, sample_from_cdf from data_generator.sampler import EmpiricalSampler, sample_from_cdf
...@@ -87,7 +89,7 @@ class Synthesizer: ...@@ -87,7 +89,7 @@ class Synthesizer:
assert ( assert (
isinstance(self.prefix_len_multiplier, float) isinstance(self.prefix_len_multiplier, float)
and self.prefix_len_multiplier > 0 and self.prefix_len_multiplier > 0
), "context_len_multiplier must be a positive float" ), "prefix_len_multiplier must be a positive float"
assert ( assert (
isinstance(self.prompt_len_multiplier, float) isinstance(self.prompt_len_multiplier, float)
and self.prompt_len_multiplier > 0 and self.prompt_len_multiplier > 0
...@@ -134,28 +136,9 @@ class Synthesizer: ...@@ -134,28 +136,9 @@ class Synthesizer:
self.G.nodes[SUPER_ROOT]["visited"] = num_paths self.G.nodes[SUPER_ROOT]["visited"] = num_paths
self.max_hash_id = max_hash_id self.max_hash_id = max_hash_id
invalid_nodes = [(node, d) for node, d in self.G.in_degree() if d > 1] _verify_tree(self.G)
if invalid_nodes: _mark_visited(self.G)
print("ERROR: The following nodes have multiple parents (in-degree > 1):") self.G = _merge_chains(self.G) # make graph radix-like
for node, in_degree in invalid_nodes:
parents = list(self.G.predecessors(node))
print(f" Node {node}: in-degree={in_degree}, parents={parents}")
raise ValueError(
"Graph is not a valid tree: nodes with multiple parents detected"
)
# visits to leaf nodes (non-core branches) are considered as ended
for node in self.G.nodes():
if "to_leaf" not in self.G.nodes[node]:
self.G.nodes[node]["to_leaf"] = 0
if self.G.nodes[node]["visited"] <= 1:
continue
for child in self.G.successors(node):
if self.G.nodes[child]["visited"] == 1:
self.G.nodes[node]["to_leaf"] += 1
# make graph radix-like
self.G = _merge_chains(self.G)
self.G, leaves_lens = _remove_leaves(self.G) self.G, leaves_lens = _remove_leaves(self.G)
# Apply prompt_len_multiplier to leaves_lens # Apply prompt_len_multiplier to leaves_lens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment