"vscode:/vscode.git/clone" did not exist on "685168f5a1679fb643d739aaecb48d09918ddf20"
prepare_datasets.py 5.23 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
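"""Prepare packed token datasets for pretraining.

Reads plain-text (.txt) files, either from a CSV index or from an input
directory, tokenizes them with a Tiktoken or HuggingFace tokenizer, and
writes the token ids as flat uint16 streams to train.bin and val.bin.

Example invocation (paths and the tokenizer name are illustrative):

    python prepare_datasets.py --input_data_dir ./raw_texts --output_data_dir ./data/my_dataset --hf_tokenizer_path meta-llama/Llama-2-7b-hf
"""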
import argparse
import datetime
import glob
import os.path

import numpy as np
import pandas as pd

from allamo.logging import configure_logger, logger

EOS_TOKEN = "</s>"

def init_tokenizer(output_dir, tiktoken_tokenizer_name, hf_tokenizer_path):
    if hf_tokenizer_path:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)
        vocab_size = len(tokenizer)
        logger.info(f"HuggingFace {hf_tokenizer_path} tokenizer loaded with vocab size {vocab_size}")
    elif tiktoken_tokenizer_name:
        import tiktoken
        tokenizer = tiktoken.get_encoding(tiktoken_tokenizer_name)
        vocab_size = tokenizer.max_token_value + 1  # token ids start at 0
        logger.info(f"Tiktoken {tiktoken_tokenizer_name} tokenizer loaded with vocab size {vocab_size}")
    else:
        raise ValueError('Tokenizer is not provided. Please specify either a Tiktoken tokenizer or a HuggingFace tokenizer')
    return tokenizer
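
# Usage sketch (the tokenizer names below are illustrative, not required by this script):
#   tokenizer = init_tokenizer(output_dir, "gpt2", None)                      # tiktoken encoding
#   tokenizer = init_tokenizer(output_dir, None, "meta-llama/Llama-2-7b-hf")  # HuggingFace tokenizer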

def load_list_of_txt_files(index_file_path, input_data_dir, data_split):
    if index_file_path:
        # Load the CSV index into a pandas dataframe
        index_df = pd.read_csv(index_file_path)

        # Replace any NaN values in the "File" column with an empty string
        index_df['File'] = index_df['File'].fillna('')

        # Keep only rows where the "File" column ends with ".txt";
        # copy to avoid pandas SettingWithCopyWarning when adding the "Split" column
        txt_files_df = index_df[index_df['File'].str.endswith('.txt')].copy()
        if 'Split' not in txt_files_df.columns:
            txt_files_df['Split'] = data_split if data_split else 'train'
    elif input_data_dir:
        txt_files = glob.glob(os.path.join(input_data_dir, "*.txt"))
        txt_files_df = pd.DataFrame({'File': txt_files})
        txt_files_df['Split'] = data_split if data_split else 'train'
    else:
        raise ValueError('Either an index file or an input data dir must be provided')

    logger.info(f"{len(txt_files_df)} txt files found to process")
    return txt_files_df
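
# Expected index CSV layout (illustrative example; only "File" is required,
# and "Split" falls back to --data_split when the column is missing):
#
#   File,Split
#   books/novel_01.txt,train
#   books/novel_02.txt,test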
    

def encode_file(input_file, output_file, tokenizer):
    # Tokenize the whole document and append it to the packed binary stream
    enc_data = np.array(tokenizer.encode(input_file.read()), dtype=np.uint16)
    enc_data.tofile(output_file)
    tokens = len(enc_data)

    # Append the EOS marker so consecutive documents stay separated
    eos_data = np.array(tokenizer.encode(EOS_TOKEN), dtype=np.uint16)
    eos_data.tofile(output_file)
    tokens += len(eos_data)

    return tokens
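
# The .bin files are flat streams of uint16 token ids, which assumes the
# tokenizer vocabulary fits in 16 bits (ids below 65536). A minimal sketch of
# loading a finished dataset back, e.g. for training:
#   data = np.memmap('train.bin', dtype=np.uint16, mode='r')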


def create_datasets(txt_files_df, tokenizer, input_data_dir, output_data_dir):
    train_tokens = 0
    val_tokens = 0
    files_cnt = 0
    # Make sure the output directory exists before opening the dataset files
    os.makedirs(output_data_dir, exist_ok=True)
    train_file_path = os.path.join(output_data_dir, 'train.bin')
    val_file_path = os.path.join(output_data_dir, 'val.bin')
    with open(train_file_path, 'wb+') as train_file, open(val_file_path, 'wb+') as val_file:
        # Process each txt file, routing 'test' rows to val.bin and everything else to train.bin
        for _, row in txt_files_df.iterrows():
            filename = os.path.join(input_data_dir, row['File']) if input_data_dir else row['File']
            if not os.path.isfile(filename):
                logger.warning(f"File {filename} does not exist.")
                continue
            with open(filename, 'r', encoding="utf-8") as txt_file:
                logger.info(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - Start processing {filename}")
                if row['Split'] == 'test':
                    tokens = encode_file(txt_file, val_file, tokenizer)
                    val_tokens += tokens
                else:
                    tokens = encode_file(txt_file, train_file, tokenizer)
                    train_tokens += tokens
                logger.info(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {filename} added ({tokens} tokens) to the {row['Split']} dataset")
                files_cnt += 1

    total_tokens = train_tokens + val_tokens
    logger.info(f"Datasets created in {output_data_dir} from {files_cnt} files. Tokens: {total_tokens:,} (Train: {train_tokens:,} Val: {val_tokens:,})")

if __name__ == '__main__':
    configure_logger()
    parser = argparse.ArgumentParser(description='Prepare your datasets')
    parser.add_argument('--index_file', type=str, help='Path to a CSV index file with a "File" column and an optional "Split" column')
    parser.add_argument('--input_data_dir', type=str, help='Path to a directory with txt files')
    parser.add_argument('--data_split', type=str, default='train', choices=['train', 'test'], help='Split to assign when the index file has no "Split" column')
    parser.add_argument('--output_data_dir', type=str, required=True, help='Path to a directory for the output dataset files')
    parser.add_argument('--tiktoken_tokenizer_name', type=str, help='Tiktoken encoding name (a name accepted by tiktoken.get_encoding)')
    parser.add_argument('--hf_tokenizer_path', type=str, help='HuggingFace tokenizer name or path')

    args = parser.parse_args()
    tokenizer = init_tokenizer(args.output_data_dir, args.tiktoken_tokenizer_name, args.hf_tokenizer_path)
    txt_files_df = load_list_of_txt_files(args.index_file, args.input_data_dir, args.data_split)
    create_datasets(txt_files_df, tokenizer, args.input_data_dir, args.output_data_dir)
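
    # Optional sanity check (sketch, not part of the original flow): decode a few
    # tokens from the packed train.bin back to text to confirm the round trip.
    #   sample = np.memmap(os.path.join(args.output_data_dir, 'train.bin'), dtype=np.uint16, mode='r')[:32]
    #   logger.info(tokenizer.decode(sample.tolist()))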