import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
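
# Illustrative note (comments only, nothing is executed here): for a pair such as
#   source = "提问:...回答:" and target = "...answer text..." + eos_token,
# preprocess() tokenizes source+target as one sequence, copies it into `labels`,
# and masks the first `source_len` label positions with IGNORE_INDEX so the loss
# is only computed on the answer tokens.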


class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including "回答:", and the target is everything after it.
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
            # Keep only the prompt part of each line: everything up to and including "回答:".
            all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # Read all lines of the train_file into a list.
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data['prompt'] + " 回答:"

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

    # Python repr() and str() representations of the dataset.
    def __repr__(self):
        return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


'''
EasySFTDataset accepts a plain text file that is read line by line. By default the dataset
groups consecutive lines together up to max_length so the LLM can learn the texts with more
context. If individual lines are unrelated, set is_group_texts to False.
'''


class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # Read the data_file line by line.
        with open(data_file, "r", encoding="UTF-8") as f:
            # Encode each line and collect the raw token id lists in raw_input_ids.
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # If encoded_ids is longer than max_length, split it into chunks of at most max_length.
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    #pad the current_input_ids to max_length with tokenizer.pad_token_id
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    # Start the next group with the chunk that did not fit into the previous one.
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # No grouping: pad each encoded line individually to max_length.
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # Get one item from the dataset.
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # Human-readable description of the dataset (used by repr()).
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    # Same description, used by str() / print().
    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"