from functools import partial
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
SYSTEM_PROMPT = """Respond in the following format:
...
...
"""
def get_tokenization_stats(example, tokenizer=None):
    """Tokenize one dataset row through the tokenizer's chat template.

    Args:
        example: Mapping with a ``"text"`` field holding the raw prompt text.
        tokenizer: Hugging Face tokenizer providing ``apply_chat_template``.
            Effectively required; the ``None`` default exists only so the
            function can be bound with ``functools.partial`` before ``map``.

    Returns:
        Dict with a single key ``"input_ids"``: the token ids produced by the
        chat template with the generation prompt appended.

    Raises:
        ValueError: If ``tokenizer`` is not supplied.
    """
    # Fail fast with a clear message instead of the opaque AttributeError
    # that `None.apply_chat_template` would raise below.
    if tokenizer is None:
        raise ValueError(
            "get_tokenization_stats requires a tokenizer; bind one via "
            "functools.partial(get_tokenization_stats, tokenizer=...)"
        )
    messages = {
        "prompt": [
            # improves adherence to the system prompt by having it in the user context
            {"role": "user", "content": SYSTEM_PROMPT + "\n\n" + example["text"]},
        ],
    }
    inputs = tokenizer.apply_chat_template(
        messages["prompt"], tokenize=True, add_generation_prompt=True
    )
    return {
        "input_ids": inputs,
    }
def get_dataset_lengths(dataset):
    """Return the per-row token count of the ``"input_ids"`` column.

    Args:
        dataset: Object exposing ``dataset.data.column("input_ids")`` (an
            Arrow-backed Hugging Face ``Dataset``) whose rows are sequences
            of token ids.

    Returns:
        1-D ``np.ndarray`` of ``int64`` lengths, one entry per row (empty
        array for an empty dataset).
    """
    input_ids = dataset.data.column("input_ids")
    # The previous np.vectorize(len)-over-object-array approach broke in two
    # cases: an empty column (np.vectorize raises unless otypes is given) and
    # an all-equal-length column (numpy builds a 2-D array, applying len() to
    # individual ids). Iterating Python lists directly avoids both.
    rows = input_ids.to_pylist() if hasattr(input_ids, "to_pylist") else input_ids
    return np.fromiter((len(row) for row in rows), dtype=np.int64, count=len(rows))
def main():
    """Report the longest tokenized prompt in the GSM8K answers train split.

    Downloads the dataset and the Qwen2.5 tokenizer, maps every row through
    the chat template, then prints the maximum input length in tokens.
    """
    dataset = load_dataset("skrishna/gsm8k_only_answer", split="train")
    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    # Bind the tokenizer so map() sees a single-argument callable; the raw
    # columns are dropped once tokenized.
    tokenize_row = partial(get_tokenization_stats, tokenizer=qwen_tokenizer)
    dataset = dataset.map(tokenize_row, remove_columns=["text", "label"])
    longest = get_dataset_lengths(dataset).max()
    print(f"Max input length: {longest}")
if __name__ == "__main__":
main()