# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json from collections import Counter from data_generator.logging_utils import calculate_and_print_statistics class PrefixAnalyzer: """ A class for analyzing dataset characteristics related to prefixes, hash IDs, and cache hit rates. """ def __init__(self, dataset_path, block_size=1): """ Initialize the analyzer with dataset path and block size. Args: dataset_path: Path to the JSONL dataset file block_size: Size of each block for prefix calculation """ self.dataset_path = dataset_path self.block_size = block_size self.dataset = self._load_dataset() self.hash_counter = self._build_hash_counter() self.repeated_hash_ids = { hash_id for hash_id, count in self.hash_counter.items() if count > 1 } def _load_dataset(self) -> list: print(f"Loading dataset from {self.dataset_path}...") dataset = [] with open(self.dataset_path, "r") as f: for line in f: dataset.append(json.loads(line)) print(f"Dataset loaded: {len(dataset)} examples") return dataset def _build_hash_counter(self) -> Counter: all_hash_ids = [] for item in self.dataset: all_hash_ids.extend(item["hash_ids"]) counter = Counter(all_hash_ids) print(f"Hash counter built: {len(counter)} unique hash IDs") return counter def analyze(self) -> dict[str, list]: """ Analyze dataset to extract various length metrics and print statistics. Returns: Tuple of lists: (input_lengths, prefix_lengths, user_prompt_lengths, output_lengths) """ # Extract input and output lengths directly from fields input_lengths = [item["input_length"] for item in self.dataset] output_lengths = [item["output_length"] for item in self.dataset] # Calculate prefix length and user prompt length for each row prefix_lengths = [] user_prompt_lengths = [] for i, item in enumerate(self.dataset): input_len = item["input_length"] hash_ids = item["hash_ids"] assert len(hash_ids) * self.block_size >= input_len # Special case: if all hash IDs in the row are repeated elsewhere if all(hash_id in self.repeated_hash_ids for hash_id in hash_ids): prefix_len = input_len # Set prefix length to input length user_prompt_len = 0 # Set user prompt length to 0 else: # Count how many hash IDs in this row are repeated elsewhere in the dataset repeated_count = sum( 1 for hash_id in hash_ids if hash_id in self.repeated_hash_ids ) prefix_len = repeated_count * self.block_size user_prompt_len = input_len - prefix_len prefix_lengths.append(prefix_len) user_prompt_lengths.append(user_prompt_len) # Check if prefix length is greater than input length if prefix_len > input_len: print(f"WARNING: Line {i}: {json.dumps(item)}") cache_hit_rates = self._analyze_cache_hit_rates() # Print statistics table metrics = { "Input Length": input_lengths, "Context Length": prefix_lengths, "Unique Prompt Length": user_prompt_lengths, "Output Length": output_lengths, "Theoretical Hit Rates": cache_hit_rates, } calculate_and_print_statistics(metrics) return metrics def _analyze_cache_hit_rates(self) -> list[float]: """ Analyze theoretical cache hit rates based on hash ID repetition. Assumes that hash IDs are cached as the dataset is iterated through, i.e., each hash ID is considered "cached" after its first appearance, similar to how KV caching would work in real life. Assumes the cache is infinite in size (hence "theoretical"), so no hash IDs are ever evicted. Returns: List of cache hit rates for each row in the dataset """ # Set to track all hash IDs we've seen seen_hash_ids = set() # Store cache hit rates for each row cache_hit_rates = [] for item in self.dataset: hash_ids = item["hash_ids"] # Skip if there are no hash IDs if len(hash_ids) == 0: continue # Find the first index where the hash ID hasn't been seen before first_unseen_idx = len(hash_ids) # Default if all are seen for idx, hash_id in enumerate(hash_ids): if hash_id not in seen_hash_ids: first_unseen_idx = idx break # Calculate cache hit rate cache_hit_rate = first_unseen_idx / len(hash_ids) cache_hit_rates.append(cache_hit_rate) # Add all hash IDs from this row to the seen set seen_hash_ids.update(hash_ids) return cache_hit_rates def main(): import argparse parser = argparse.ArgumentParser(description="Analyze prefix dataset statistics") parser.add_argument( "--input-file", type=str, default="mooncake_trace.jsonl", help="Path to the input dataset file (default: mooncake_trace.jsonl)", ) parser.add_argument( "--block-size", type=int, default=512, help="Block size for prefix calculation (default: 512)", ) args = parser.parse_args() block_size = args.block_size dataset_path = args.input_file print(f"Analyzing dataset: {dataset_path}") print(f"Using block size: {block_size}") print() # Create analyzer instance analyzer = PrefixAnalyzer(dataset_path, block_size=block_size) analyzer.analyze() if __name__ == "__main__": main()