prefix_analyzer.py 6.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from collections import Counter

19
from data_generator.logging_utils import calculate_and_print_statistics
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187


class PrefixAnalyzer:
    """
    A class for analyzing dataset characteristics related to prefixes, hash IDs, and cache hit rates.

    Each dataset row is a JSON object that must provide at least the keys
    ``hash_ids`` (list of block hash IDs), ``input_length`` and ``output_length``.

    Attributes:
        dataset_path: Path the dataset was loaded from.
        block_size: Size of one hash block, in the same units as ``input_length``.
        dataset: Parsed rows, one dict per non-blank JSONL line.
        hash_counter: Occurrence count of every hash ID across all rows.
        repeated_hash_ids: Hash IDs occurring more than once in the dataset.
    """

    def __init__(self, dataset_path, block_size=1):
        """
        Initialize the analyzer with dataset path and block size.

        The dataset is loaded and indexed eagerly, so construction reads the file.

        Args:
            dataset_path: Path to the JSONL dataset file
            block_size: Size of each block for prefix calculation; must be positive

        Raises:
            ValueError: If ``block_size`` is not positive.
            OSError: If the dataset file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        # Validate before touching the file so bad arguments fail fast.
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size!r}")
        self.dataset_path = dataset_path
        self.block_size = block_size
        self.dataset = self._load_dataset()
        self.hash_counter = self._build_hash_counter()
        # Counts include repeats within a single row, not only across rows.
        self.repeated_hash_ids = {
            hash_id for hash_id, count in self.hash_counter.items() if count > 1
        }

    def _load_dataset(self) -> list:
        """
        Read ``self.dataset_path`` as JSONL and return the list of parsed rows.

        Blank (whitespace-only) lines are ignored, so trailing empty lines in
        the file do not cause a decode error.
        """
        print(f"Loading dataset from {self.dataset_path}...")
        # JSON exchanged between systems is UTF-8 (RFC 8259); don't rely on locale.
        with open(self.dataset_path, "r", encoding="utf-8") as f:
            dataset = [json.loads(line) for line in f if line.strip()]
        print(f"Dataset loaded: {len(dataset)} examples")
        return dataset

    def _build_hash_counter(self) -> Counter:
        """Count how many times each hash ID appears across all rows."""
        counter = Counter()
        # Update incrementally instead of materializing one huge list of IDs.
        for item in self.dataset:
            counter.update(item["hash_ids"])
        print(f"Hash counter built: {len(counter)} unique hash IDs")
        return counter

    def analyze(self) -> dict[str, list]:
        """
        Analyze dataset to extract various length metrics and print statistics.

        For each row, the context (prefix) length is the number of its hash IDs
        that occur more than once in the whole dataset, times ``block_size``.
        If every hash ID in the row is repeated, the whole input counts as
        context and the unique prompt length is 0.

        Returns:
            Dict mapping metric name to its list of values: "Input Length",
            "Context Length", "Unique Prompt Length" and "Output Length" (one
            entry per row), plus "Theoretical Hit Rates" (one entry per row
            that has at least one hash ID).

        Raises:
            ValueError: If a row's hash IDs cannot cover its input length,
                i.e. ``len(hash_ids) * block_size < input_length``.
        """
        # Extract input and output lengths directly from fields
        input_lengths = [item["input_length"] for item in self.dataset]
        output_lengths = [item["output_length"] for item in self.dataset]

        # Calculate prefix length and user prompt length for each row
        prefix_lengths = []
        user_prompt_lengths = []

        for i, item in enumerate(self.dataset):
            input_len = item["input_length"]
            hash_ids = item["hash_ids"]
            # Explicit check (not assert) so validation survives `python -O`.
            if len(hash_ids) * self.block_size < input_len:
                raise ValueError(
                    f"Row {i}: {len(hash_ids)} hash IDs x block size "
                    f"{self.block_size} cannot cover input length {input_len}"
                )

            # Number of hash IDs in this row that occur more than once dataset-wide
            repeated_count = sum(
                1 for hash_id in hash_ids if hash_id in self.repeated_hash_ids
            )

            if repeated_count == len(hash_ids):
                # Every block is shared: use input_len rather than
                # repeated_count * block_size, presumably so a partially filled
                # final block is not over-counted.
                prefix_len = input_len
                user_prompt_len = 0
            else:
                prefix_len = repeated_count * self.block_size
                user_prompt_len = input_len - prefix_len

            prefix_lengths.append(prefix_len)
            user_prompt_lengths.append(user_prompt_len)

            # Check if prefix length is greater than input length
            if prefix_len > input_len:
                print(f"WARNING: Line {i}: {json.dumps(item)}")

        cache_hit_rates = self._analyze_cache_hit_rates()

        # Print statistics table
        metrics = {
            "Input Length": input_lengths,
            "Context Length": prefix_lengths,
            "Unique Prompt Length": user_prompt_lengths,
            "Output Length": output_lengths,
            "Theoretical Hit Rates": cache_hit_rates,
        }

        calculate_and_print_statistics(metrics)

        return metrics

    def _analyze_cache_hit_rates(self) -> list[float]:
        """
        Analyze theoretical cache hit rates based on hash ID repetition.

        Assumes that hash IDs are cached as the dataset is iterated through,
        i.e., each hash ID is considered "cached" after its first appearance,
        similar to how KV caching would work in real life.
        Assumes the cache is infinite in size (hence "theoretical"), so no hash IDs are ever evicted.

        Only the leading run of already-seen hash IDs counts as a hit: once
        an unseen ID is found, the rest of the row counts as a miss even if
        some later IDs were seen before.

        Returns:
            Hit rate in [0, 1] for each row that has at least one hash ID, in
            dataset order. Rows with an empty ``hash_ids`` list are skipped,
            so the result may be shorter than the dataset.
        """
        seen_hash_ids = set()
        cache_hit_rates = []

        for item in self.dataset:
            hash_ids = item["hash_ids"]

            # Skip if there are no hash IDs (rate would be 0/0)
            if not hash_ids:
                continue

            # Index of the first not-yet-cached hash ID; len(hash_ids) if all cached
            first_unseen_idx = next(
                (
                    idx
                    for idx, hash_id in enumerate(hash_ids)
                    if hash_id not in seen_hash_ids
                ),
                len(hash_ids),
            )
            cache_hit_rates.append(first_unseen_idx / len(hash_ids))

            # Everything in this row is cached for subsequent rows
            seen_hash_ids.update(hash_ids)

        return cache_hit_rates


def main():
    """
    Command-line entry point: parse options and print prefix statistics.

    Options:
        --input-file: JSONL dataset to analyze (default: mooncake_trace.jsonl).
        --block-size: Block size passed to PrefixAnalyzer (default: 512).
    """
    import argparse

    cli = argparse.ArgumentParser(description="Analyze prefix dataset statistics")

    # (flag, type, default, help) for each supported option, in display order.
    option_specs = (
        (
            "--input-file",
            str,
            "mooncake_trace.jsonl",
            "Path to the input dataset file (default: mooncake_trace.jsonl)",
        ),
        (
            "--block-size",
            int,
            512,
            "Block size for prefix calculation (default: 512)",
        ),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    options = cli.parse_args()

    print(f"Analyzing dataset: {options.input_file}")
    print(f"Using block size: {options.block_size}")
    print()

    PrefixAnalyzer(options.input_file, block_size=options.block_size).analyze()


if __name__ == "__main__":
    main()