Unverified Commit ec438f8c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: minor robustness improvement for datagen analyze (#3483)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 0a5df130
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
import random import random
import pandas as pd import pandas as pd
from prefix_data_generator.hasher import RollingHasher
from tqdm import tqdm from tqdm import tqdm
...@@ -154,8 +155,9 @@ def convert_to_mooncake(df, block_size, num_hash_blocks): ...@@ -154,8 +155,9 @@ def convert_to_mooncake(df, block_size, num_hash_blocks):
DataFrame in mooncake format with columns: timestamp, input_length, output_length, hash_ids DataFrame in mooncake format with columns: timestamp, input_length, output_length, hash_ids
""" """
mooncake_data = [] mooncake_data = []
hasher = RollingHasher() # Initialize once to maintain global state
for _, row in tqdm(df.iterrows(), total=len(df)): for idx, row in tqdm(df.iterrows(), total=len(df)):
# Convert timestamp from seconds to milliseconds (integer) # Convert timestamp from seconds to milliseconds (integer)
timestamp_ms = int(row["Timestamp"] * 1000) timestamp_ms = int(row["Timestamp"] * 1000)
...@@ -166,11 +168,15 @@ def convert_to_mooncake(df, block_size, num_hash_blocks): ...@@ -166,11 +168,15 @@ def convert_to_mooncake(df, block_size, num_hash_blocks):
# Calculate hash array length based on block size # Calculate hash array length based on block size
hash_array_length = math.ceil(input_length / block_size) hash_array_length = math.ceil(input_length / block_size)
# Generate random hash IDs # Generate random content blocks (each block is a tuple of random integers)
hash_ids = [ # Using request index as seed for reproducibility
random.randint(0, num_hash_blocks) for _ in range(hash_array_length) random.seed(idx)
content_blocks = [
(random.randint(0, num_hash_blocks),) for _ in range(hash_array_length)
] ]
hash_ids = hasher(content_blocks)
mooncake_data.append( mooncake_data.append(
{ {
"timestamp": timestamp_ms, "timestamp": timestamp_ms,
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re import re
from typing import Dict, List, Union, cast from typing import Dict, List, Sequence, Union, cast
import numpy as np import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
...@@ -30,6 +18,59 @@ lorem_text = ( ...@@ -30,6 +18,59 @@ lorem_text = (
words = np.array(list(set(re.findall(r"\b[a-zA-Z]+\b", lorem_text)))) words = np.array(list(set(re.findall(r"\b[a-zA-Z]+\b", lorem_text))))
class RollingHasher:
"""
A stateful rolling hasher that converts blocks of content into globally unique hash IDs.
This class maintains a mapping from content hashes to unique integer IDs across multiple
sequences. Each block's hash depends on its content and the hash of the previous block
(rolling/chained hashing).
Usage:
hasher = RollingHasher()
hash_ids = hasher(blocks) # blocks is List[List[int]] or List[tuple]
"""
def __init__(self):
"""Initialize the hasher with empty state."""
self.hash_to_int: Dict[int, int] = {}
self.next_int = 0
def __call__(self, blocks: Sequence[Sequence[int]]) -> List[int]:
"""
Convert a sequence of blocks into a sequence of unique hash IDs.
Args:
blocks: Sequence of blocks, where each block is a sequence of integers
Returns:
List of integer hash IDs, one per block
"""
parent_hash = 0
hashes: List[int] = []
for block in blocks:
# Convert block to tuple for hashing
block_tuple = tuple(block) if not isinstance(block, tuple) else block
combined = (parent_hash, hash(block_tuple))
global_hash = hash(combined)
# Map global_hash to a unique integer
if global_hash not in self.hash_to_int:
self.hash_to_int[global_hash] = self.next_int
self.next_int += 1
hashes.append(self.hash_to_int[global_hash])
parent_hash = global_hash
return hashes
def reset(self):
"""Reset the hasher state (clear all mappings)."""
self.hash_to_int.clear()
self.next_int = 0
def texts_to_hashes( def texts_to_hashes(
tokenizer: Union[str, PreTrainedTokenizerBase], tokenizer: Union[str, PreTrainedTokenizerBase],
texts: List[str], texts: List[str],
...@@ -64,30 +105,15 @@ def texts_to_hashes( ...@@ -64,30 +105,15 @@ def texts_to_hashes(
# batch_encoding["input_ids"] is a List[List[int]] # batch_encoding["input_ids"] is a List[List[int]]
all_tokens: List[List[int]] = batch_encoding["input_ids"] all_tokens: List[List[int]] = batch_encoding["input_ids"]
# Initialize the rolling hasher
hasher = RollingHasher()
results: List[List[int]] = [] results: List[List[int]] = []
hash_to_int: Dict[int, int] = {}
next_int = 0
for tokens in all_tokens: for tokens in all_tokens:
blocks: List[List[int]] = [ blocks: List[List[int]] = [
tokens[i : i + block_size] for i in range(0, len(tokens), block_size) tokens[i : i + block_size] for i in range(0, len(tokens), block_size)
] ]
hashes = hasher(blocks)
parent_hash = 0
hashes: List[int] = []
for block in blocks:
combined = (parent_hash, hash(tuple(block)))
global_hash = hash(combined)
# Map global_hash to a unique integer
if global_hash not in hash_to_int:
hash_to_int[global_hash] = next_int
next_int += 1
hashes.append(hash_to_int[global_hash])
parent_hash = global_hash
results.append(hashes) results.append(hashes)
return results return results
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from collections import Counter from collections import Counter
...@@ -37,7 +25,9 @@ class PrefixAnalyzer: ...@@ -37,7 +25,9 @@ class PrefixAnalyzer:
self.dataset = self._load_dataset() self.dataset = self._load_dataset()
self.hash_counter = self._build_hash_counter() self.hash_counter = self._build_hash_counter()
self.repeated_hash_ids = { self.repeated_hash_ids = {
hash_id for hash_id, count in self.hash_counter.items() if count > 1 (pos, hash_id)
for (pos, hash_id), count in self.hash_counter.items()
if count > 1
} }
def _load_dataset(self) -> list: def _load_dataset(self) -> list:
...@@ -50,11 +40,12 @@ class PrefixAnalyzer: ...@@ -50,11 +40,12 @@ class PrefixAnalyzer:
return dataset return dataset
def _build_hash_counter(self) -> Counter: def _build_hash_counter(self) -> Counter:
all_hash_ids = [] all_hash_positions = []
for item in self.dataset: for item in self.dataset:
all_hash_ids.extend(item["hash_ids"]) for pos, hash_id in enumerate(item["hash_ids"]):
counter = Counter(all_hash_ids) all_hash_positions.append((pos, hash_id))
print(f"Hash counter built: {len(counter)} unique hash IDs") counter = Counter(all_hash_positions)
print(f"Hash counter built: {len(counter)} unique (position, hash_id) pairs")
return counter return counter
def analyze(self) -> dict[str, list]: def analyze(self) -> dict[str, list]:
...@@ -77,14 +68,19 @@ class PrefixAnalyzer: ...@@ -77,14 +68,19 @@ class PrefixAnalyzer:
hash_ids = item["hash_ids"] hash_ids = item["hash_ids"]
assert len(hash_ids) * self.block_size >= input_len assert len(hash_ids) * self.block_size >= input_len
# Special case: if all hash IDs in the row are repeated elsewhere # Special case: if all (position, hash_id) pairs in the row are repeated elsewhere
if all(hash_id in self.repeated_hash_ids for hash_id in hash_ids): if all(
(pos, hash_id) in self.repeated_hash_ids
for pos, hash_id in enumerate(hash_ids)
):
prefix_len = input_len # Set prefix length to input length prefix_len = input_len # Set prefix length to input length
user_prompt_len = 0 # Set user prompt length to 0 user_prompt_len = 0 # Set user prompt length to 0
else: else:
# Count how many hash IDs in this row are repeated elsewhere in the dataset # Count how many (position, hash_id) pairs in this row are repeated elsewhere in the dataset
repeated_count = sum( repeated_count = sum(
1 for hash_id in hash_ids if hash_id in self.repeated_hash_ids 1
for pos, hash_id in enumerate(hash_ids)
if (pos, hash_id) in self.repeated_hash_ids
) )
prefix_len = repeated_count * self.block_size prefix_len = repeated_count * self.block_size
user_prompt_len = input_len - prefix_len user_prompt_len = input_len - prefix_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment