"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "3c7c1d64ce3994ccf247d164a3287b3f89e2278e"
Commit 9cdba76d authored by Yan Ru Pei, committed by GitHub

feat: data synthesizer based on prefix statistics (#1087)


Signed-off-by: Yan Ru Pei <yanrpei@gmail.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent a20445de
@@ -54,7 +54,7 @@ jobs:
 env:
 PYTEST_MARKS: "pre_merge or mypy"
 run: |
-docker run -w /workspace --name ${{ env.CONTAINER_ID }}_pytest ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
+docker run -w /workspace --name ${{ env.CONTAINER_ID }}_pytest ${{ steps.define_image_tag.outputs.image_tag }} bash -c "pip install -e /workspace/benchmarks && pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m \"${{ env.PYTEST_MARKS }}\""
 - name: Copy test report from test Container
 if: always()
 run: |
@@ -77,4 +77,4 @@ jobs:
 uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
 with:
 name: Event File
 path: ${{ github.event_path }}
\ No newline at end of file
<!-- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Benchmarks
This directory contains benchmarking scripts and tools for performance evaluation.
## Installation
To install the necessary dependencies locally, run:
```bash
pip install -e .
```
Currently, this will install lightweight tools for:
- Analyzing prefix-structured data (`datagen analyze`)
- Synthesizing customizable structured data for testing purposes (`datagen synthesize`)
\ No newline at end of file
<!-- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
## Trace File Format
The following tools help analyze and synthesize new data based on the [mooncake trace file format](https://github.com/kvcache-ai/Mooncake/blob/d21da178bae8db9651cf18a76824c084145fc725/mooncake_trace.jsonl). In this format, the first few lines would look like this, for example:
```
{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 3052, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 3052, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
```
**Hash ID Generation:** Each new hash ID is the next consecutive integer after the last one used. When two `hash_ids` lists start with the same integers, those shared integers represent their prefix overlap. To generate these increasing hash IDs from a list of texts, we provide the `texts_to_hashes` function in `hasher.py`.
**Timestamp:** The arrival time of the request in milliseconds, measured from the first request. Requests that arrive simultaneously share the same timestamp.
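**Block Size and Hash IDs:** In this example, the `block_size` (the page size of the KV cache) is assumed to be 512. The length of the `hash_ids` array equals `ceil(input_length / block_size)`, since a partially filled final block still gets its own hash ID. For instance, the first row has an `input_length` of 6755 and 14 hash IDs, because 6755 / 512 ≈ 13.2.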
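As a quick sanity check, the sketch below compares the number of hash IDs in two of the sample rows with the number of blocks needed to cover `input_length`. It assumes `block_size = 512` as in the example; the helper name `num_blocks` is introduced here only for illustration and is not part of the package.

```python
import json

BLOCK_SIZE = 512  # assumed, matching the example above

sample_lines = [
    '{"timestamp": 0, "input_length": 6755, "output_length": 500, '
    '"hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}',
    '{"timestamp": 3052, "input_length": 2287, "output_length": 316, '
    '"hash_ids": [0, 42, 43, 44, 45]}',
]


def num_blocks(input_length: int, block_size: int) -> int:
    """Ceiling division: a partially filled last block still counts."""
    return -(-input_length // block_size)


rows = [json.loads(line) for line in sample_lines]
block_counts = [num_blocks(row["input_length"], BLOCK_SIZE) for row in rows]
hash_counts = [len(row["hash_ids"]) for row in rows]
```

Both lists should agree row by row, which is the same invariant the analyzer relies on when it checks that `len(hash_ids) * block_size >= input_length`.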
## Prefix Analyzer
The Prefix Analyzer provides statistics on a trace file, such as Input Sequence Length (ISL), Output Sequence Length (OSL), and theoretical cache hit rate.
It is useful for understanding the structure and reuse patterns in your dataset.
```bash
datagen analyze --input-file <path_to_trace.jsonl> --block-size <block_size>
```
- `--input-file`: Path to your trace file in jsonl format (default: `mooncake_trace.jsonl`)
- `--block-size`: Block size for prefix calculation (default: 512)
The script will print out summary statistics for ISL, OSL, user prompt lengths, and the theoretical cache hit rate (assuming an infinite cache).
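To make the "theoretical cache hit rate" concrete, here is a minimal sketch under stated assumptions: the cache is infinite, rows are processed in file order, and a row's hit rate is the fraction of its hash IDs in the leading run that has already been seen. The function name `theoretical_hit_rates` is hypothetical and chosen for this illustration.

```python
def theoretical_hit_rates(rows):
    """Per-row hit rate with an infinite cache that is filled in row order."""
    seen = set()
    rates = []
    for hash_ids in rows:
        if not hash_ids:
            continue  # rows without hash IDs contribute no rate
        hits = 0
        for hash_id in hash_ids:
            if hash_id not in seen:
                break  # the prefix match ends at the first unseen block
            hits += 1
        rates.append(hits / len(hash_ids))
        seen.update(hash_ids)
    return rates


rows = [[0, 1, 2, 3], [0, 1], [0, 1, 2], [0, 4, 5]]
rates = theoretical_hit_rates(rows)
```

The first row always misses entirely because nothing has been cached yet, while later rows hit on whatever prefix they share with earlier rows.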
## Synthesizer
The Synthesizer goes a step further:
It builds a prefix tree from the original trace file, extracts prefix statistics, and generates a new synthetic dataset based on these statistics.
You can control various aspects of the synthetic data generation with tunable knobs, such as request rate, context/prompt length multipliers, and the number of tree copies.
This is useful for generating large, realistic synthetic traces for benchmarking or simulation, while preserving the structural properties of the original dataset.
### How to run
```bash
datagen synthesize --input-file <path_to_trace.jsonl> --num-requests <N> [other options...]
```
**Options:**
- `--input-file`: Path to the input trace file (default: `mooncake_trace.jsonl`)
- `--num-requests`: Number of requests to synthesize (default: 100000)
- `--speedup-ratio`: Factor to speed up request intervals. It effectively divides the synthetic timestamps by this value (default: 1)
- `--prefix-len-multiplier`: Multiplier for prefix lengths (default: 1.0)
- `--prefix-root-multiplier`: Number of times to replicate the core radix tree (default: 1)
- `--prompt-len-multiplier`: Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)
- `--max-isl`: Maximum input sequence length to include in output (default: None, no filtering)
- `--block-size`: Block size for prefilling and decoding (default: 512)
- `--output-file`: Path to the output file (default: auto-generated from input file and options)
### Example
Say we only have these hash lists:
```
[0, 1, 2, (3)]
[0, 1]
[0, 1, 2]
[0, (4), (5)]
```
First, we identify the "core prefix nodes" as [0, 1, 2] since they are visited more than once. The nodes [3, 4, 5] are considered "user prompts" as they appear only once (marked in parentheses).
If we set the `prefix-len-multiplier` to 2, then the core prefix branches will be stretched, effectively giving:
```
[0, 1, 2, 3, 4, 5, (6)]
[0, 1, 2, 3]
[0, 1, 2, 3, 4, 5]
[0, 1, (7), (8)]
```
Note that the "prompt branches" are not stretched by `prefix-len-multiplier`. They can be separately modified by applying `prompt-len-multiplier`.
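The stretching in this toy example can be reproduced with a short sketch. It is only an illustration, not the synthesizer's radix-tree implementation: it assumes an integer multiplier, assumes the core hash IDs are the consecutive integers starting at 0, and simply gives every user-prompt block a fresh ID. The name `stretch_prefixes` is hypothetical.

```python
from collections import Counter
from itertools import count


def stretch_prefixes(rows, multiplier):
    """Expand every core (repeated) hash ID into `multiplier` consecutive IDs."""
    counts = Counter(h for row in rows for h in row)
    core = {h for h, c in counts.items() if c > 1}
    # Fresh IDs for user-prompt blocks start right after the stretched core range.
    fresh = count(multiplier * (max(core) + 1) if core else 0)
    stretched = []
    for row in rows:
        new_row = []
        for h in row:
            if h in core:
                new_row.extend(range(h * multiplier, (h + 1) * multiplier))
            else:
                new_row.append(next(fresh))  # prompt blocks are not stretched
        stretched.append(new_row)
    return stretched


rows = [[0, 1, 2, 3], [0, 1], [0, 1, 2], [0, 4, 5]]
stretched = stretch_prefixes(rows, 2)
```

With a multiplier of 2, the result matches the four stretched rows listed above, with the prompt blocks relabeled as 6, 7, and 8.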
Now, if we set `prefix-root-multiplier` to 2, then each row has a 50 percent chance of being incremented by a large integer. The incremented rows are effectively separated into a new radix tree, which matches the statistics of the original one but has completely different roots.
For example, if rows 2 and 4 are offset, then we would get:
```
[0, 1, 2, 3, 4, 5, (6)]
[10, 11, 12, 13]
[0, 1, 2, 3, 4, 5]
[10, 11, (14), (15)]
```
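A rough sketch of this replication step, assuming each row is assigned uniformly at random to one of the tree copies and shifted by a fixed offset per copy. The function name `replicate_roots`, the `offset` parameter, and the seeding are assumptions made for this illustration. Unlike the example above, the sketch shifts user-prompt IDs by the same offset rather than relabeling them, and the offset must exceed the largest hash ID to avoid collisions.

```python
import random


def replicate_roots(rows, num_copies, offset, seed=0):
    """Shift each row into a randomly chosen copy of the tree."""
    rng = random.Random(seed)
    shifted = []
    for row in rows:
        copy_idx = rng.randrange(num_copies)  # 0 keeps the original IDs
        shifted.append([h + copy_idx * offset for h in row])
    return shifted


rows = [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3], [0, 1, 2, 3, 4, 5], [0, 1, 7, 8]]
copies = replicate_roots(rows, num_copies=2, offset=10)
```

Rows that land in the same copy still share prefixes with each other, while rows in different copies share nothing.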
### Implementation details
The generation algorithm, simplified, is as follows:
- Store the hash ids in a directed tree structure (prefix tree)
- Each directed edge `weight` indicates how many times the edge is traversed, which is needed to compute transition probabilities.
- Contract unary paths (chains) in the tree so that it is in radix-tree form, meaning every node that is an only child is contracted with its parent. As a consequence, each node needs to store an attribute `length` to indicate the compressed length (1 if no compression). The depth multiplier (`--prefix-len-multiplier`) scales this compressed length (rounded to the nearest integer), effectively increasing the length of each radix node.
- Identify every leaf node that is visited only once, and prune them from the tree, as they are highly likely not part of the core radix tree. In other words, we do not need to store nodes that are part of the actual user prompts.
- At this stage, each node will have (possibly zero) transition probabilities to a child prefix node, to a "user prompt" node, and to a "termination" node. Use these probabilities to sample a path in the core radix tree, then append to the path new hash ids corresponding to a user prompt whose length is sampled from the dataset. The width multiplier (`--prefix-root-multiplier`) effectively duplicates the entire radix tree the specified number of times, each with a new set of hash ids, creating more diverse request patterns.
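The first two steps (storing the hash ids as a tree and counting edge traversals) can be sketched with plain dictionaries. This is a simplified stand-in, not the package's `networkx`-based code: the helper names are hypothetical, and the "user prompt" and "termination" outcomes are merged into a single `"stop"` entry for brevity.

```python
from collections import Counter, defaultdict

SUPER_ROOT = -1  # dummy root placed before every real node


def build_edge_weights(rows):
    """Count how many times each parent -> child edge is traversed."""
    weights = defaultdict(Counter)
    for row in rows:
        parent = SUPER_ROOT
        for hash_id in row:
            weights[parent][hash_id] += 1
            parent = hash_id
    return weights


def transition_probs(weights, node, visits):
    """Turn edge counts at `node` into probabilities; the remainder is 'stop'."""
    probs = {child: w / visits for child, w in weights[node].items()}
    probs["stop"] = 1 - sum(probs.values())
    return probs


rows = [[0, 1, 2, 3], [0, 1], [0, 1, 2], [0, 4, 5]]
weights = build_edge_weights(rows)
visits_of_1 = weights[0][1]  # node 1 is reached only through node 0
probs_at_1 = transition_probs(weights, 1, visits_of_1)
```

Here node 1 is visited three times and continues to node 2 twice, so a sampled path continues with probability 2/3 and stops at node 1 otherwise.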
## Testing
To test for "correctness", or faithfulness to the original trace statistics, one can run
```bash
python -m benchmarks.data_utils.synthesizer \
    --input-file mooncake_trace.jsonl \
    --num-requests 500000
```
and compare the synthetic ISL statistics (mean, median, std) to the original ISL statistics, which one can obtain by running
```bash
python -m benchmarks.data_utils.prefix_analyzer \
    --input-file mooncake_trace.jsonl
```
I find this to be the most "robust" end-to-end test. It is important to sample a large number of requests (e.g., hundreds of thousands) to ensure the statistics are meaningful, due to the law of large numbers. In particular, the mean statistics (such as mean ISL) should be well preserved in the synthetic data. However, the standard deviation statistics—especially for ISL—are not expected to match exactly, since the synthesizer does not capture the correlation between context length and prompt length present in the original data.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from data_generator.cli import main as cli_main


def main():
    cli_main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys


def main():
    parser = argparse.ArgumentParser(
        description="Data generation and analysis tools for benchmarking",
        prog="datagen",
    )

    # Add subparsers for commands
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Create the parser for the "analyze" command
    subparsers.add_parser("analyze", help="Analyze data")

    # Create the parser for the "synthesize" command
    subparsers.add_parser("synthesize", help="Synthesize data")

    # Unrecognized arguments are forwarded to the selected subcommand
    args, remaining = parser.parse_known_args()

    if args.command == "analyze":
        # Import and run the analyzer main
        from data_generator import prefix_analyzer

        sys.argv = [sys.argv[0]] + remaining
        prefix_analyzer.main()
    elif args.command == "synthesize":
        # Import and run the synthesizer main
        from data_generator import synthesizer

        sys.argv = [sys.argv[0]] + remaining
        synthesizer.main()
    else:
        # No subcommand given: show usage instead of exiting silently
        parser.print_help()


if __name__ == "__main__":
    main()
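The dispatch above relies on `parse_known_args` returning the arguments it does not recognize, which are then handed to the subcommand's own parser via `sys.argv`. The standalone sketch below shows that behavior with an explicit argument list; it is an illustration and does not import the package.

```python
import argparse

parser = argparse.ArgumentParser(prog="datagen")
subparsers = parser.add_subparsers(dest="command")
subparsers.add_parser("analyze")
subparsers.add_parser("synthesize")

# "--block-size 64" is unknown to both the top-level parser and the
# "analyze" subparser, so it is returned untouched in `remaining`.
args, remaining = parser.parse_known_args(["analyze", "--block-size", "64"])
```

Keeping the subparsers empty means each tool owns its own option definitions, at the cost of `datagen --help` not listing them.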
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx
import numpy as np
from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT
from data_generator.sampler import get_cdf
def _merge_chains(G: nx.DiGraph) -> nx.DiGraph:
    """
    Make the graph radix-like (meaning all unary paths are contracted).

    This function transforms a prefix tree into a radix tree structure by contracting
    unary paths (chains of nodes with exactly one predecessor and one successor).
    The resulting radix tree is significantly more compact than the original prefix tree,
    as it eliminates redundant intermediate nodes while preserving the structural
    information needed for path sampling.

    This compression is particularly beneficial for efficient path sampling during data
    synthesis. In addition, keep track of the contracted lengths in the 'length' attribute
    of each node to preserve the original path information.

    Args:
        G (networkx.DiGraph): A directed graph representing a prefix tree structure.

    Returns:
        networkx.DiGraph: The resulting radix tree with unary paths contracted.
    """
    for visited in sorted(np.unique([G.nodes[node]["visited"] for node in G.nodes()])):
        sub_nodes = [node for node in G.nodes() if G.nodes[node]["visited"] == visited]
        subgraph = G.subgraph(sub_nodes)
        if len(subgraph) == 1:
            continue

        chain_nodes = [
            node
            for node in subgraph.nodes()
            if G.in_degree(node) == 1 and G.out_degree(node) == 1
        ]
        if not chain_nodes:
            continue
        chain_nodes = sorted(chain_nodes)

        nodes_rm = []
        for node in chain_nodes:
            node_pred = list(G.predecessors(node))[0]
            # find the parent node source
            if G.nodes[node_pred]["visited"] == visited and node_pred != SUPER_ROOT:
                continue
            weight = G[node_pred][node]["weight"]

            end_node = node
            chain_len = 1
            succ = list(G.successors(end_node))

            # find the end of the chain
            while succ and G.nodes[succ[0]]["visited"] == visited:
                nodes_rm.append(end_node)
                end_node = succ[0]
                chain_len += 1
                succ = list(G.successors(end_node))

            G.add_edge(node_pred, end_node, weight=weight)
            G.nodes[end_node]["length"] = chain_len

        G.remove_nodes_from(nodes_rm)

    for node in G.nodes():
        if "length" not in G.nodes[node]:
            G.nodes[node]["length"] = 1

    return G
def _remove_leaves(G: nx.DiGraph) -> tuple[nx.DiGraph, list[int]]:
    """
    Remove all nodes that are only visited once from the tree.

    This function removes nodes representing unique user prompts (nodes with visited=1)
    from the radix tree, leaving only the "core radix tree" structure that contains
    commonly traversed paths. The removed nodes typically represent leaf paths that
    were accessed only once and don't contribute to the core structural patterns.

    Args:
        G (networkx.DiGraph): A directed graph representing a radix tree structure.

    Returns:
        tuple[networkx.DiGraph, list[int]]: A tuple containing:
            - The modified graph with unique nodes removed
            - A list of lengths of the removed leaf nodes
    """
    leaves = {
        node: G.nodes[node]["length"]
        for node in G.nodes()
        if G.nodes[node]["visited"] == 1
    }
    leaves_id = list(leaves.keys())
    leaves_len = list(leaves.values())
    G.remove_nodes_from(leaves_id)

    return G, leaves_len
def _precompute_transition_cdfs(G: nx.DiGraph) -> nx.DiGraph:
    for node in G.nodes():
        out_edges = list(G.out_edges(node))
        weights = [G[edge[0]][edge[1]]["weight"] for edge in out_edges] + [
            G.nodes[node]["to_leaf"],
            G.nodes[node]["end"],
        ]
        G.nodes[node]["out_cdf"] = get_cdf(weights)
        G.nodes[node]["out_nodes"] = [edge[1] for edge in out_edges] + [
            CACHE_END,
            END_NODE,
        ]

    return G
def _validate_graph(G: nx.DiGraph) -> bool:
    for node in G.nodes():
        # Skip nodes without parents or children
        if G.in_degree(node) == 0 or G.out_degree(node) == 0:
            continue

        # Get incoming edge weight (should only be one parent)
        parent = list(G.predecessors(node))[0]
        in_weight = G[parent][node]["weight"]

        # Sum outgoing edge weights
        out_weights = [G[node][child]["weight"] for child in G.successors(node)]
        out_weights += [G.nodes[node]["to_leaf"], G.nodes[node]["end"]]

        # Compare weights (using np.isclose for float comparison)
        if not np.isclose(in_weight, sum(out_weights)):
            raise ValueError(
                f"Weight mismatch at node {node}: "
                f"incoming weight {in_weight} != sum of outgoing weights {sum(out_weights)}"
            )

    return True
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
def texts_to_hashes(
    tokenizer: PreTrainedTokenizerBase, texts: List[str], block_size: int = 512
) -> List[List[int]]:
    """
    Tokenizes a list of strings (without special tokens), splits tokens into blocks,
    computes rolling hashes, and returns a list of lists of integer-mapped rolling hashes
    for each input string.

    Args:
        tokenizer: Tokenizer object that can be called on a batch of strings.
        texts (List[str]): List of input strings.
        block_size (int): Size of each token block for hashing.

    Returns:
        List[List[int]]: List of lists of integer-mapped rolling hashes for each block of each input string.
    """
    # Batch tokenize for efficiency
    batch_encoding = tokenizer(
        texts,
        add_special_tokens=False,
        return_attention_mask=False,
        return_token_type_ids=False,
    )
    # batch_encoding["input_ids"] is a List[List[int]]
    all_tokens: List[List[int]] = batch_encoding["input_ids"]

    results: List[List[int]] = []
    hash_to_int: Dict[int, int] = {}
    next_int = 0

    for tokens in all_tokens:
        blocks: List[List[int]] = [
            tokens[i : i + block_size] for i in range(0, len(tokens), block_size)
        ]

        parent_hash = 0
        hashes: List[int] = []

        for block in blocks:
            # Chain each block hash with its parent so equal blocks only match
            # when everything before them matches too
            combined = (parent_hash, hash(tuple(block)))
            global_hash = hash(combined)

            # Map global_hash to a unique integer
            if global_hash not in hash_to_int:
                hash_to_int[global_hash] = next_int
                next_int += 1

            hashes.append(hash_to_int[global_hash])
            parent_hash = global_hash

        results.append(hashes)

    return results
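The chained hashing used by `texts_to_hashes` can be tried without a tokenizer by feeding token ID lists directly. The sketch below is a simplified stand-in under that assumption; the name `tokens_to_hashes` is hypothetical and not part of `hasher.py`.

```python
def tokens_to_hashes(all_tokens, block_size):
    """Map each block to an integer ID that depends on the whole prefix before it."""
    hash_to_int = {}
    results = []
    for tokens in all_tokens:
        parent_hash = 0
        ids = []
        for i in range(0, len(tokens), block_size):
            block = tuple(tokens[i : i + block_size])
            # Chaining with the parent hash makes the ID prefix-dependent
            parent_hash = hash((parent_hash, hash(block)))
            # New hashes get the next consecutive integer
            ids.append(hash_to_int.setdefault(parent_hash, len(hash_to_int)))
        results.append(ids)
    return results


ids = tokens_to_hashes([[1, 2, 3, 4], [1, 2, 9, 9], [1, 2, 3, 4, 5]], block_size=2)
```

Sequences that share leading blocks share leading IDs, and every previously unseen block receives the next consecutive integer, which is the numbering convention of the trace format.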
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import numpy as np
import pandas as pd
from tabulate import tabulate
def calculate_and_print_statistics(metrics: Dict[str, List[float]]) -> pd.DataFrame:
    """
    Calculate statistics for a dictionary of metrics and print them in a tabular format.

    Args:
        metrics: Dictionary where keys are metric names and values are lists of metric values

    Returns:
        pandas.DataFrame: DataFrame containing the calculated statistics
    """
    metric_names = []
    stats_data = []

    # Calculate statistics for each metric
    for metric_name, values in metrics.items():
        metric_names.append(metric_name)
        stats_data.append(
            {
                "Mean": np.mean(values),
                "Std Dev": np.std(values),
                "Min": np.min(values),
                "P25": np.percentile(values, 25),
                "Median": np.median(values),
                "P75": np.percentile(values, 75),
                "Max": np.max(values),
            }
        )

    # Print the statistics as a table using tabulate
    stats_df = pd.DataFrame(stats_data, index=metric_names)
    print(tabulate(stats_df, headers="keys", tablefmt="pretty", floatfmt=".2f"), "\n")

    return stats_df
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections import Counter
from data_generator.logging import calculate_and_print_statistics
class PrefixAnalyzer:
    """
    A class for analyzing dataset characteristics related to prefixes, hash IDs, and cache hit rates.
    """

    def __init__(self, dataset_path, block_size=1):
        """
        Initialize the analyzer with dataset path and block size.

        Args:
            dataset_path: Path to the JSONL dataset file
            block_size: Size of each block for prefix calculation
        """
        self.dataset_path = dataset_path
        self.block_size = block_size
        self.dataset = self._load_dataset()
        self.hash_counter = self._build_hash_counter()
        self.repeated_hash_ids = {
            hash_id for hash_id, count in self.hash_counter.items() if count > 1
        }

    def _load_dataset(self) -> list:
        print(f"Loading dataset from {self.dataset_path}...")
        dataset = []
        with open(self.dataset_path, "r") as f:
            for line in f:
                dataset.append(json.loads(line))
        print(f"Dataset loaded: {len(dataset)} examples")
        return dataset

    def _build_hash_counter(self) -> Counter:
        all_hash_ids = []
        for item in self.dataset:
            all_hash_ids.extend(item["hash_ids"])
        counter = Counter(all_hash_ids)
        print(f"Hash counter built: {len(counter)} unique hash IDs")
        return counter
    def analyze(self) -> dict[str, list]:
        """
        Analyze dataset to extract various length metrics and print statistics.

        Returns:
            Dictionary mapping each metric name (input, context, unique prompt, and
            output lengths, plus theoretical hit rates) to its list of per-row values
        """
        # Extract input and output lengths directly from fields
        input_lengths = [item["input_length"] for item in self.dataset]
        output_lengths = [item["output_length"] for item in self.dataset]

        # Calculate prefix length and user prompt length for each row
        prefix_lengths = []
        user_prompt_lengths = []

        for i, item in enumerate(self.dataset):
            input_len = item["input_length"]
            hash_ids = item["hash_ids"]
            assert len(hash_ids) * self.block_size >= input_len

            # Special case: if all hash IDs in the row are repeated elsewhere
            if all(hash_id in self.repeated_hash_ids for hash_id in hash_ids):
                prefix_len = input_len  # Set prefix length to input length
                user_prompt_len = 0  # Set user prompt length to 0
            else:
                # Count how many hash IDs in this row are repeated elsewhere in the dataset
                repeated_count = sum(
                    1 for hash_id in hash_ids if hash_id in self.repeated_hash_ids
                )
                prefix_len = repeated_count * self.block_size
                user_prompt_len = input_len - prefix_len

            prefix_lengths.append(prefix_len)
            user_prompt_lengths.append(user_prompt_len)

            # Check if prefix length is greater than input length
            if prefix_len > input_len:
                print(f"WARNING: Line {i}: {json.dumps(item)}")

        cache_hit_rates = self._analyze_cache_hit_rates()

        # Print statistics table
        metrics = {
            "Input Length": input_lengths,
            "Context Length": prefix_lengths,
            "Unique Prompt Length": user_prompt_lengths,
            "Output Length": output_lengths,
            "Theoretical Hit Rates": cache_hit_rates,
        }
        calculate_and_print_statistics(metrics)

        return metrics
    def _analyze_cache_hit_rates(self) -> list[float]:
        """
        Analyze theoretical cache hit rates based on hash ID repetition.

        Assumes that hash IDs are cached as the dataset is iterated through,
        i.e., each hash ID is considered "cached" after its first appearance,
        similar to how KV caching would work in real life.
        Assumes the cache is infinite in size (hence "theoretical"), so no hash IDs are ever evicted.

        Returns:
            List of cache hit rates for each row in the dataset
        """
        # Set to track all hash IDs we've seen
        seen_hash_ids = set()

        # Store cache hit rates for each row
        cache_hit_rates = []

        for item in self.dataset:
            hash_ids = item["hash_ids"]

            # Skip if there are no hash IDs
            if len(hash_ids) == 0:
                continue

            # Find the first index where the hash ID hasn't been seen before
            first_unseen_idx = len(hash_ids)  # Default if all are seen
            for idx, hash_id in enumerate(hash_ids):
                if hash_id not in seen_hash_ids:
                    first_unseen_idx = idx
                    break

            # Calculate cache hit rate
            cache_hit_rate = first_unseen_idx / len(hash_ids)
            cache_hit_rates.append(cache_hit_rate)

            # Add all hash IDs from this row to the seen set
            seen_hash_ids.update(hash_ids)

        return cache_hit_rates
def main():
    import argparse

    parser = argparse.ArgumentParser(description="Analyze prefix dataset statistics")
    parser.add_argument(
        "--input-file",
        type=str,
        default="mooncake_trace.jsonl",
        help="Path to the input dataset file (default: mooncake_trace.jsonl)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=512,
        help="Block size for prefix calculation (default: 512)",
    )
    args = parser.parse_args()

    block_size = args.block_size
    dataset_path = args.input_file

    print(f"Analyzing dataset: {dataset_path}")
    print(f"Using block size: {block_size}")
    print()

    # Create analyzer instance
    analyzer = PrefixAnalyzer(dataset_path, block_size=block_size)
    analyzer.analyze()


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Protocol-level constants for synthetic data graph structure.
"""
SUPER_ROOT = -1 # Dummy node preceding all real nodes; not an actual data root
CACHE_END = -2  # Special node: the shared (cached) prefix ends and a user-prompt leaf follows
END_NODE = -3  # Special node: the path terminates here, so leaf sampling is skipped
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from numpy.random import Generator
logger = logging.getLogger(__name__)
def get_cdf(weights: List[float]) -> np.ndarray:
    cumsum = np.cumsum(weights)
    return cumsum / cumsum[-1]


def data_to_cdf(data: np.ndarray) -> Tuple[List[Any], np.ndarray]:
    sorted_counter: Dict[Any, int] = dict(sorted(Counter(data).items()))
    data_unique: List[Any] = list(sorted_counter.keys())
    counter_cdf: np.ndarray = get_cdf(list(sorted_counter.values()))
    return data_unique, counter_cdf


def sample_from_cdf(
    data: List[Any], cdf: np.ndarray, rng: Optional[Generator] = None
) -> Any:
    # NOTE: assumes (but does not verify) that the CDF is valid
    # CDF stands for cumulative distribution function
    assert len(data) == len(cdf)
    if rng is not None:
        return data[np.searchsorted(cdf, rng.random())]
    else:
        return data[np.searchsorted(cdf, np.random.rand())]
class EmpiricalSampler:
    """
    Takes data, learns from the pure empirical distribution, and samples directly from it.

    Args:
        data (Union[List[Any], np.ndarray]): The input data to learn the distribution from.
    """

    def __init__(self, data: Union[List[Any], np.ndarray]) -> None:
        self.rng = np.random.default_rng(0)
        self.empty_data = len(data) == 0
        if self.empty_data:
            logger.warning("Empty data provided to EmpiricalSampler")
        else:
            self.data, self.cdf = data_to_cdf(np.array(data))

    def sample(self) -> Any:
        if self.empty_data:
            return 0
        return sample_from_cdf(self.data, self.cdf, self.rng)
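The sampling above is inverse-CDF sampling: draw a uniform number in [0, 1) and pick the first item whose cumulative probability reaches it. The sketch below reproduces the idea with the standard library only, using `bisect_left` in place of `np.searchsorted` (both return the leftmost insertion point by default). The names `cdf_from_weights` and `pick` are hypothetical, and the uniform draw is passed in explicitly so the result is deterministic.

```python
from bisect import bisect_left
from itertools import accumulate


def cdf_from_weights(weights):
    """Normalized cumulative sums, e.g. [3, 1] -> [0.75, 1.0]."""
    cumsum = list(accumulate(weights))
    return [c / cumsum[-1] for c in cumsum]


def pick(data, cdf, u):
    """Return the first item whose cumulative probability is >= u."""
    return data[bisect_left(cdf, u)]


cdf = cdf_from_weights([3, 1])
picks = [pick(["child", "leaf"], cdf, u) for u in (0.1, 0.75, 0.9)]
```

With weights 3 and 1, any draw up to 0.75 selects the first item and larger draws select the second, so the items are chosen in a 3:1 ratio on average.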
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections import Counter
from typing import Any, Optional
import networkx as nx
import numpy as np
import pandas as pd
from data_generator.graph_utils import (
_merge_chains,
_precompute_transition_cdfs,
_remove_leaves,
)
from data_generator.protocols import CACHE_END, END_NODE, SUPER_ROOT
from data_generator.sampler import EmpiricalSampler, sample_from_cdf
class Synthesizer:
    def __init__(
        self,
        dataset_file: str,
        block_size: int = 512,
        num_copies: int = 1,
        speedup_ratio: float = 1.0,
        prefix_len_multiplier: float = 1.0,
        prompt_len_multiplier: float = 1.0,
    ):
        """Load the mooncake dataset and extract core statistics like
        radix-tree structure, ISL, OSL, and request timings.
        Generate synthetic datasets based on these statistics, with tunable knobs,
        e.g. to increase request rate or the ISL.

        A request is broken into two parts: a context and a prompt. A context is
        any block that is (can possibly be) visited more than once, while a prompt
        is considered to be unique and only visited once (user prompt).

        Args:
            dataset_file (str): The mooncake trace file in jsonl format.
            block_size (int, optional): The block size for prefilling and decoding.
                Defaults to 512.
            num_copies (int, optional): Number of times to replicate the core radix tree.
                Defaults to 1.
            speedup_ratio (float, optional): For speeding up the request intervals.
                Defaults to 1.
            prefix_len_multiplier (float, optional): For every node in the core radix-tree,
                increase the substring length by this multiplier, and rounded to the nearest
                multiple of the block size. In other words, shared prefix prompts will be
                expanded by this factor. Defaults to 1.
            prompt_len_multiplier (float, optional): Multiplies the leaf path lengths by this factor
                (rounded to integers). Use values < 1 to generate shorter prompts. Defaults to 1.
                Note this does not affect the lengths of the core context prompts.

        NOTE: currently may only work for the mooncake trace file,
        as it assumes consecutive integers
        NOTE: If the prefix_len_multiplier is not one, then the synthetic data
        cannot be mixed and matched with the original trace file,
        as the hash ids will be relabeled.
        """
self.block_size = block_size
self.num_copies = num_copies
self.speedup_ratio = float(speedup_ratio)
self.prefix_len_multiplier = float(prefix_len_multiplier)
self.prompt_len_multiplier = float(prompt_len_multiplier)
# assert correct arg bounds
assert (
isinstance(self.num_copies, int) and self.num_copies >= 1
), "num_copies must be an integer greater than or equal to 1"
assert (
isinstance(self.speedup_ratio, float) and self.speedup_ratio > 0
), "speedup_ratio must be a positive float"
assert (
isinstance(self.prefix_len_multiplier, float)
and self.prefix_len_multiplier > 0
), "prefix_len_multiplier must be a positive float"
assert (
isinstance(self.prompt_len_multiplier, float)
and self.prompt_len_multiplier > 0
), "prompt_len_multiplier must be a positive float"
# extract data from json file
with open(dataset_file, "r") as f:
hash_ids_list = []
timestamps = []
input_lens = []
output_lens = []
for line in f:
data = json.loads(line)
hash_ids_list.append(np.array(data["hash_ids"]))
timestamps.append(int(data["timestamp"]))
input_lens.append(int(data["input_length"]))
output_lens.append(int(data["output_length"]))
# represent prefix-tree as directed graph
self.G = nx.DiGraph()
max_hash_id = SUPER_ROOT
num_paths = 0
self.G.add_node(SUPER_ROOT, end=0)
for hash_ids in hash_ids_list:
num_paths += 1
for i in range(len(hash_ids)):
u = hash_ids[i - 1] if i > 0 else SUPER_ROOT
v = hash_ids[i]
max_hash_id = max(v, max_hash_id)
if v in self.G:
self.G.nodes[v]["visited"] += 1
else:
self.G.add_node(v, visited=1, end=0)
if self.G.has_edge(u, v):
self.G[u][v]["weight"] += 1
else:
self.G.add_edge(u, v, weight=1)
self.G.nodes[v]["end"] += 1
self.G.nodes[SUPER_ROOT]["visited"] = num_paths
self.max_hash_id = max_hash_id
invalid_nodes = [(node, d) for node, d in self.G.in_degree() if d > 1]
if invalid_nodes:
print("ERROR: The following nodes have multiple parents (in-degree > 1):")
for node, in_degree in invalid_nodes:
parents = list(self.G.predecessors(node))
print(f" Node {node}: in-degree={in_degree}, parents={parents}")
raise ValueError(
"Graph is not a valid tree: nodes with multiple parents detected"
)
# visits to leaf nodes (non-core branches) are considered as ended
for node in self.G.nodes():
if "to_leaf" not in self.G.nodes[node]:
self.G.nodes[node]["to_leaf"] = 0
if self.G.nodes[node]["visited"] <= 1:
continue
for child in self.G.successors(node):
if self.G.nodes[child]["visited"] == 1:
self.G.nodes[node]["to_leaf"] += 1
# make graph radix-like
self.G = _merge_chains(self.G)
self.G, leaves_lens = _remove_leaves(self.G)
# Apply prompt_len_multiplier to leaves_lens
if self.prompt_len_multiplier != 1:
leaves_lens = [
max(1, round(length * self.prompt_len_multiplier))
for length in leaves_lens
]
self.leaves_lens_sampler = EmpiricalSampler(leaves_lens)
self._relabel_nodes()
self.G = _precompute_transition_cdfs(self.G)
# get statistics of timing, request counts, ISL, and OSL
request_counts = list(Counter(timestamps).values())
self.request_counts_sampler = EmpiricalSampler(request_counts)
timedeltas = np.diff(timestamps)
timedeltas = timedeltas[timedeltas > 0]
self.timedeltas_sampler = EmpiricalSampler(timedeltas)
input_lens_mod = np.array(
[
input_len - (len(hash_ids) - 1) * block_size
for input_len, hash_ids in zip(input_lens, hash_ids_list)
]
)
assert np.all(0 < input_lens_mod) and np.all(input_lens_mod <= self.block_size)
self.input_lens_mod_sampler = EmpiricalSampler(input_lens_mod)
self.output_lens_sampler = EmpiricalSampler(output_lens)
def _relabel_nodes(self) -> None:
# Scale node labels by length multiplier if needed
if self.prefix_len_multiplier > 1:
multiplier = int(np.ceil(self.prefix_len_multiplier))
# Create mapping for relabeling, preserving -1 and -2
mapping = {
node: (node if node < 0 else node * multiplier + multiplier)
for node in self.G.nodes()
}
self.G = nx.relabel_nodes(self.G, mapping)
# Update max_hash_id
self.max_hash_id = multiplier * self.max_hash_id + multiplier
# Shrink the lengths, but no need to relabel nodes
elif self.prefix_len_multiplier < 1:
for node in self.G.nodes():
self.G.nodes[node]["length"] = max(
round(self.G.nodes[node]["length"] * self.prefix_len_multiplier), 1
)
def _synthesize_leaf_path(self) -> list[int]:
# Sample the leaf path length
leaf_length = self.leaves_lens_sampler.sample()
# Generate new nodes starting from max_hash_id + 1
path = [int(self.max_hash_id + 1 + i) for i in range(leaf_length)]
# Update max_hash_id
self.max_hash_id += leaf_length
return path
def synthesize_path(self) -> tuple[list[int], bool, int]:
"""
Synthesizes a path through the core radix tree, optionally appending a unique user prompt (leaf path).
Returns:
tuple:
- list[int]: The full path as a list of hash_ids. This consists of the cached (core) hash_ids,
with new unique hash_ids appended at the end if a leaf path is included.
- bool: Whether the path contains a leaf path (i.e., new unique hash_ids were appended).
- int: The context length, defined as the number of cached hash_ids multiplied by block_size.
"""
# Start from root node (-1)
current_node = SUPER_ROOT
path: list[int] = []
context_len = 0
# Continue until we reach a node with no outgoing edges
while True:
# Use precomputed CDFs for efficient sampling
next_node = sample_from_cdf(
self.G.nodes[current_node]["out_nodes"],
self.G.nodes[current_node]["out_cdf"],
)
# end early
# break and start sampling unique user prompt
if next_node == CACHE_END:
break
# break and don't sample leaf
if next_node == END_NODE:
return path, False, context_len
# otherwise continue down prefix tree
# Get the length of the contracted path
length = self.G.nodes[next_node]["length"]
context_len += length * self.block_size
# Add all intermediate nodes
for i in range(length):
path.append(int(next_node - (length - 1) + i))
current_node = next_node
unique_user_prompt = self._synthesize_leaf_path()
# Append a leaf path at the end
return path + unique_user_prompt, True, context_len
def synthesize_requests(
self, num_requests: int, input_len_filter: Optional[int] = None
) -> list[dict[str, Any]]:
timestamp = 0
requests: list[dict[str, Any]] = []
request_id = 0
while request_id < num_requests:
requests_per_interval = self.request_counts_sampler.sample()
for _ in range(requests_per_interval):
path, leaf_flag, context_len = self.synthesize_path()
if leaf_flag:
input_len = (
len(path) - 1
) * self.block_size + self.input_lens_mod_sampler.sample()
else:
input_len = len(path) * self.block_size
output_len = self.output_lens_sampler.sample()
if input_len_filter is not None and input_len > input_len_filter:
continue
requests.append(
{
"timestamp": int(timestamp),
"input_length": int(input_len),
"output_length": int(output_len),
"hash_ids": path,
"context_len": int(context_len),
"unique_user_prompt_len": int(input_len - context_len),
}
)
request_id += 1
if request_id >= num_requests:
break
timestamp += round(self.timedeltas_sampler.sample() / self.speedup_ratio)
# Adjust hash_ids if num_copies > 1
if self.num_copies > 1:
for request in requests:
offset = (np.random.randint(0, self.num_copies)) * (
self.max_hash_id + 1
)
request["hash_ids"] = [
int(hash_id + offset) for hash_id in request["hash_ids"]
]
return requests
def __repr__(self) -> str:
path_lengths = nx.single_source_shortest_path_length(self.G, -1)
core_radix_tree_size = len(self.G) - 1
core_radix_tree_depth = max(path_lengths.values()) if path_lengths else 0
rep = "MooncakeSynth("
rep += f"core_radix_tree_size={core_radix_tree_size}, "
rep += f"core_radix_tree_depth={core_radix_tree_depth}, "
rep += f"block_size={self.block_size})"
children = list(self.G.successors(-1))
data = {
"Child Node": children,
"Visited Count": [self.G.nodes[child]["visited"] for child in children],
"Length": [self.G.nodes[child].get("length", "N/A") for child in children],
}
df = pd.DataFrame(data)
df = df[df["Visited Count"] >= 5]
df = df.sort_values("Visited Count", ascending=False)
grouped = df.groupby("Length", sort=True)
rep += "\nRoot nodes (grouped by length, visited count ≥ 5):\n"
for length, group in grouped:
nodes = group["Child Node"].tolist()
visit_counts = group["Visited Count"].tolist()
rep += f"\nNodes: {nodes}, Path Length: {length}, Visited Counts: {visit_counts}"
return rep
def main():
import argparse
from pathlib import Path
from data_generator.logging import calculate_and_print_statistics
parser = argparse.ArgumentParser(description="Synthesize Mooncake-Esque dataset")
parser.add_argument(
"--input-file",
default="mooncake_trace.jsonl",
type=str,
help="Path to the input trace file in JSONL format",
)
parser.add_argument(
"--num-requests",
type=int,
default=int(1e5),
help="Number of requests to synthesize (default: 100000)",
)
parser.add_argument(
"--speedup-ratio",
type=float,
default=1,
help="Factor to speed up request intervals (default: 1)",
)
parser.add_argument(
"--prefix-len-multiplier",
type=float,
default=1.0,
help="Multiplier for prefix lengths (default: 1.0)",
)
parser.add_argument(
"--prefix-root-multiplier",
type=int,
default=1,
help="Number of times to replicate the core radix tree (default: 1)",
)
parser.add_argument(
"--prompt-len-multiplier",
type=float,
default=1.0,
help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
)
parser.add_argument(
"--max-isl",
type=int,
default=None,
help="Maximum input sequence length to include in output (default: None, no filtering)",
)
parser.add_argument(
"--block-size",
type=int,
default=512,
help="Block size for prefilling and decoding (default: 512)",
)
parser.add_argument(
"--output-file",
type=str,
default=None,
help="Path to the output file (default: None, derived from the input file name)",
)
args = parser.parse_args()
dataset_file = Path(args.input_file).resolve()
if args.output_file is None:
output_file = dataset_file.with_stem(
f"{dataset_file.stem}_synth"
+ f"_{int(args.prefix_len_multiplier)}x{args.prefix_root_multiplier}+{args.prompt_len_multiplier}"
+ f"_speedup{args.speedup_ratio}"
+ f"_maxisl{args.max_isl}"
)
else:
output_file = Path(args.output_file).resolve()
print("learning from dataset...", flush=True)
synthesizer = Synthesizer(
str(dataset_file),
block_size=args.block_size,
speedup_ratio=args.speedup_ratio,
prefix_len_multiplier=args.prefix_len_multiplier,
num_copies=args.prefix_root_multiplier,
prompt_len_multiplier=args.prompt_len_multiplier,
)
print("synthesizing requests...", flush=True)
requests = synthesizer.synthesize_requests(args.num_requests, args.max_isl)
print(f"synthesized {len(requests)} requests")
# Print statistics in a single table with metrics as rows and statistics as columns
print("\n###### Synthesized Statistics ######")
# Extract all values first
metrics = {
"Input Length": [req["input_length"] for req in requests],
"Context Length": [req["context_len"] for req in requests],
"Unique Prompt Length": [req["unique_user_prompt_len"] for req in requests],
"Output Length": [req["output_length"] for req in requests],
}
# Calculate statistics for each metric
calculate_and_print_statistics(metrics)
with open(output_file, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
print(f"synthetic dataset saved at {Path(output_file).resolve()}")
if __name__ == "__main__":
main()
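As a standalone illustration of the input-length bookkeeping used by `Synthesizer` above (the `input_lens_mod` computation and its assertion), the following minimal sketch splits a request's `input_length` into full blocks plus a final, possibly partial, block. The helper name `split_input_length` is hypothetical and is not part of `data_generator`.

```python
def split_input_length(
    input_length: int, num_blocks: int, block_size: int = 512
) -> tuple[int, int]:
    """Split an input length into (tokens in full blocks, tokens in last block).

    Hypothetical helper for illustration only. It mirrors the expression
    `input_len - (len(hash_ids) - 1) * block_size` and the bound check
    `0 < remainder <= block_size` from the synthesizer.
    """
    full_tokens = (num_blocks - 1) * block_size
    remainder = input_length - full_tokens
    if not 0 < remainder <= block_size:
        raise ValueError(
            f"input_length={input_length} is inconsistent with "
            f"{num_blocks} blocks of size {block_size}"
        )
    return full_tokens, remainder


# A request with 3 hash ids and 1300 tokens: two full blocks, 276 tokens in the last.
print(split_input_length(1300, 3))
# A request whose last block is completely full.
print(split_input_length(1024, 2))
```

When synthesizing, the remainder is sampled independently of the path, so a request with a leaf path gets `(len(path) - 1) * block_size` plus a sampled remainder, as in `synthesize_requests`.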
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from data_generator.hasher import texts_to_hashes
from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers
from transformers import PreTrainedTokenizerFast
@pytest.fixture(scope="module")
def dummy_tokenizer():
vocab = [chr(i) for i in range(ord("a"), ord("z") + 1)]
vocab.append("[UNK]")
vocab_dict = {token: idx for idx, token in enumerate(vocab)}
tokenizer_model = models.WordLevel(vocab=vocab_dict, unk_token="[UNK]")
tokenizer = Tokenizer(tokenizer_model)
tokenizer.normalizer = normalizers.Sequence(
[normalizers.NFD(), normalizers.Lowercase()]
)
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
tokenizer.decoder = decoders.WordPiece(prefix="")
return PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]",
bos_token="[BOS]",
eos_token="[EOS]",
)
def test_texts_to_hashes_blocks(dummy_tokenizer):
dum1 = "a b c d"
dum2 = "e f g h"
dum3 = "i j k l"
texts = [dum1, dum1 + " " + dum2, dum1 + " " + dum3, dum2 + " " + dum1]
expected = [[0], [0, 1], [0, 2], [3, 4]]
result = texts_to_hashes(dummy_tokenizer, texts, block_size=4)
assert result == expected, f"Expected {expected}, got {result}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
import numpy as np
from data_generator.sampler import EmpiricalSampler
def test_empirical_sampler_distribution():
# Create a test array with equal numbers of 1, 2, and 3
test_data = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
# Create the sampler
sampler = EmpiricalSampler(test_data)
# Sample 1000 times
samples = [sampler.sample() for _ in range(1000)]
# Count occurrences of each value
counts = Counter(samples)
# Verify each number (1, 2, 3) appears between 300 and 400 times
for value in [1, 2, 3]:
assert (
300 <= counts[value] <= 400
), f"Value {value} appeared {counts[value]} times, expected 300-400 times"
# Verify no other values appear in the samples
assert set(counts.keys()) == {
1,
2,
3,
}, f"Unexpected values in samples: {set(counts.keys()) - {1, 2, 3}}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import tempfile
import unittest
from data_generator.synthesizer import Synthesizer
# Helper function to create and dump data
def dump_record(handle, hash_ids, block_size=512):
input_length = block_size * len(hash_ids)
output_length = random.randint(50, 250)
data = {
"timestamp": 1000,
"hash_ids": hash_ids,
"input_length": input_length,
"output_length": output_length,
}
json.dump(data, handle)
handle.write("\n")
def check_attributes(
graph,
node,
expected_children,
expected_visited=None,
expected_length=None,
expected_to_leaf=None,
):
# Check children
actual_children = list(graph.successors(node))
assert sorted(actual_children) == sorted(
expected_children
), f"Node {node} has children {actual_children}, expected {expected_children}"
# Check 'visited' attribute if expected
if expected_visited is not None:
assert (
graph.nodes[node].get("visited") == expected_visited
), f"Node {node} has 'visited' value {graph.nodes[node].get('visited')}, expected {expected_visited}"
# Check 'length' attribute if expected
if expected_length is not None:
assert (
graph.nodes[node].get("length") == expected_length
), f"Node {node} has 'length' value {graph.nodes[node].get('length')}, expected {expected_length}"
# Check 'to_leaf' attribute if expected
if expected_to_leaf is not None:
assert (
graph.nodes[node].get("to_leaf") == expected_to_leaf
), f"Node {node} has 'to_leaf' value {graph.nodes[node].get('to_leaf')}, expected {expected_to_leaf}"
return True
def test_graph_structure():
# Create a temporary JSONL file with the specified data
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp:
dump_record(tmp, [0, 1])
dump_record(tmp, [0, 1, 2, 3, 4])
dump_record(tmp, [0, 1, 2, 3, 4, 5, 6])
dump_record(tmp, [7, 8])
dump_record(tmp, [7, 8, 9, 10])
dump_record(tmp, [11, 12])
# Create the Synthesizer with the temporary file
synthesizer = Synthesizer(tmp.name, block_size=512)
G = synthesizer.G
# Verify the graph structure
check_attributes(G, -1, [1, 8], 6, None, 1)
check_attributes(G, 1, [4], 3, 2, 0)
check_attributes(G, 4, [], 2, 3, 1)
check_attributes(G, 8, [], 2, 2, 1)
# Clean up
os.unlink(tmp.name)
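if __name__ == "__main__":
# The tests in this module are plain pytest functions, so run them with pytest
import pytest
pytest.main([__file__])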
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[project]
name = "data-generator"
version = "0.1.0"
description = "Data generator library for LLM benchmarks"
readme = "README.md"
authors = [
{name = "NVIDIA CORPORATION & AFFILIATES"}
]
license = {text = "Apache-2.0"}
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: Information Technology",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Operating System :: POSIX :: Linux",
]
dependencies = [
"networkx",
"numpy",
"pandas",
"tabulate",
"types-tabulate",
"transformers",
"pytest-mypy",
]
[project.scripts]
datagen = "data_generator.cli:main"
[project.urls]
Repository = "https://github.com/ai-dynamo/dynamo.git"
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["data_generator"]
[tool.setuptools.package-data]
data_generator = ["**/*.py"]
[tool.mypy]
explicit_package_bases = true
ignore_missing_imports = true
check_untyped_defs = true
[tool.pytest.ini_options]
addopts = [
"-ra",
"--showlocals",
"--strict-markers",
"--strict-config",
"--mypy", # This flag enables mypy type checking during pytest runs
"--ignore-glob=*model.py",
"--ignore-glob=*_inc.py",
"--ignore-glob=deploy/cloud/api-store/*",
]
\ No newline at end of file
...@@ -65,6 +65,8 @@ Such saturation can create a feedback loop—where the cache-rich worker continu
## Tuning Guidelines
Currently, optimal use of our KV router requires understanding your backend engine's capacity and the prefix structure of your data. We provide analysis tools for this purpose in the `benchmarks` directory. In the future, we plan to enable automatic tuning of our KV router (via `Planner`) using worker feedback metrics and dynamic analysis of data prefix structures (WIP). Below are several tips we recommend following.
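
One rough way to get a feel for the prefix structure of a trace is to count how often each block hash is revisited. The sketch below is illustrative only: `prefix_reuse_stats` is a hypothetical helper, not part of the repository, and it assumes each request carries a list of `hash_ids` in which a given id always denotes the same block reached through the same prefix (as in the mooncake trace format). The reported hit rate is an upper bound that assumes a single worker with unlimited cache.

```python
from collections import Counter


def prefix_reuse_stats(requests: list[list[int]]) -> dict[str, float]:
    """Summarize block reuse across requests (hypothetical helper).

    `requests` is a list of per-request `hash_ids` lists.
    """
    # Count how many times each block hash is visited over the whole trace.
    visits = Counter(h for hash_ids in requests for h in hash_ids)
    total_blocks = sum(visits.values())
    unique_blocks = len(visits)
    # Blocks visited more than once are candidates for KV cache reuse.
    shared_blocks = sum(1 for count in visits.values() if count > 1)
    # With unlimited cache, every visit after the first one is a hit.
    max_hit_rate = 1 - unique_blocks / total_blocks if total_blocks else 0.0
    return {
        "total_blocks": total_blocks,
        "unique_blocks": unique_blocks,
        "shared_blocks": shared_blocks,
        "max_hit_rate": max_hit_rate,
    }


trace = [[0, 1], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4, 5, 6], [7, 8], [7, 8, 9, 10], [11, 12]]
print(prefix_reuse_stats(trace))
```

A low `max_hit_rate` suggests there is little for KV-aware routing to exploit, while a high value combined with a large `unique_blocks` count relative to the engine's allocated KV blocks suggests that eviction will matter.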
### 1. Consider Total KV Block Allocation
Check the total number of KV blocks allocated for your backend engine. For smaller models (e.g., 8B parameters), this can exceed one million blocks. In such cases:
......