generate_tokenized_dataset.py 2.38 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse

from datasets import load_dataset
from transformers import AutoTokenizer


def prepare_dataset(tokenizer, text_file_path: str):
    """
    Tokenizes a text file where each line is a different example.
    Padding is applied to each example.
    """
    # Each line is a different example
    dataset = load_dataset("text", data_files={"train": text_file_path})

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    return tokenized_dataset["train"]


def generate_tokenized_dataset(tokenizer_path: str, text_file_path: str, output_dir: str) -> None:
    """
    Generate tokenized dataset from a text file, where each line is a different example.

    Args:
        tokenizer_path (str): Path to the directory containing the tokenizer files.
        text_file_path (str): Path to the text file to tokenize.
        output_dir (str): Directory where the tokenized dataset will be saved
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token

    train_dataset = prepare_dataset(tokenizer, text_file_path)
    train_dataset.save_to_disk(output_dir)


if __name__ == "__main__":
    # Example usage:
    # python generate_tokenized_dataset.py --tokenizer_path /shared/public/models/Mistral-7B --text_file_path ./../../resources/tiny_shakespeare.txt --output_dir ./../../resources/tiny_shakespeare_tokenized
    parser = argparse.ArgumentParser(description="Generate tokenized dataset from a text file.")

    # Add arguments
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="Path to the directory containing the tokenizer files.",
    )
    parser.add_argument(
        "--text_file_path",
        type=str,
        required=True,
        help="Path to the text file to tokenize.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where the tokenized dataset will be saved.",
    )

    # Parse the arguments
    args = parser.parse_args()

    # Call the function with parsed arguments
    generate_tokenized_dataset(
        tokenizer_path=args.tokenizer_path,
        text_file_path=args.text_file_path,
        output_dir=args.output_dir,
    )