preprocess.py 11.1 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from datasets import Dataset as HfDataset
from tqdm import tqdm

from .template import History

# Type alias for a dataset preprocessor: a callable that takes an HfDataset
# and returns an HfDataset.
PreprocessFunc = Callable[[HfDataset], HfDataset]


class SwiftPreprocessor:
    """Clean up the optional `history` and `system` columns of a dataset.

    * `history`: string entries are parsed with `ast.literal_eval`, `None`
      entries become empty lists. The column is re-added only when at least
      one row ends up with a non-empty history; otherwise it is dropped.
    * `system`: the column is dropped when every value is `None` or `''`.
    """

    def __call__(self, dataset: HfDataset) -> HfDataset:
        if 'history' in dataset.features:
            raw_history = dataset['history']
            parsed_history: List[History] = []
            found_non_empty = False
            for item in tqdm(raw_history):
                if isinstance(item, str):
                    # History stored as the text of a Python literal.
                    item = ast.literal_eval(item)
                elif item is None:
                    item = []
                if len(item) > 0:
                    found_non_empty = True
                parsed_history.append(item)
            # Always drop the raw column; restore it only if it carries data.
            dataset = dataset.remove_columns(['history'])
            if found_non_empty:
                dataset = dataset.add_column('history', parsed_history)
        if 'system' in dataset.features:
            meaningful = [value for value in dataset['system'] if value not in {None, ''}]
            if not meaningful:
                dataset = dataset.remove_columns(['system'])
        return dataset


class AlpacaPreprocessor:
    """Convert an Alpaca-style dataset into `query` / `response` columns.

    Each row must provide `instruction` and `output`; `input`, `history`,
    `system` and `tools` are optional. Rows whose `output` is None are
    dropped. The `history`, `system` and `tools` columns are emitted only if
    at least one row carries a non-None value for them.
    """

    def __init__(self, concat_inst_inp: Optional[Callable[[str, str], str]] = None):
        """
        Args:
            concat_inst_inp: Optional callable that merges the instruction and
                the (non-empty) input into the query. When None, the two are
                joined with a newline.
        """
        self.concat_inst_inp = concat_inst_inp

    def __call__(self, dataset: HfDataset) -> HfDataset:
        """Build a new dataset with `query`, `response` and optional extra columns."""
        query: List[str] = []
        response = []
        system = None
        history = None
        tools = None
        for d in tqdm(dataset):
            inst, inp = d['instruction'], d.get('input', None)
            h, output = d.pop('history', None), d['output']
            sys_msg = d.pop('system', None)
            tool = d.pop('tools', None)
            # Optional columns are created lazily on the first non-None value
            # and back-filled with one None per row emitted so far. The count
            # is len(query), not the loop index, because rows with a None
            # output are skipped; this keeps every column aligned with `query`.
            if history is None and h is not None:
                history = [None] * len(query)
            if system is None and sys_msg is not None:
                system = [None] * len(query)
            if tools is None and tool is not None:
                tools = [None] * len(query)
            if output is None:
                continue
            if inp is None or len(inp) == 0:
                q = inst
            elif self.concat_inst_inp is not None:
                q = self.concat_inst_inp(inst, inp)
            else:
                q = f'{inst}\n{inp}'
            query.append(q)
            response.append(output)
            if history is not None:
                history.append(h)
            if system is not None:
                system.append(sys_msg)
            if tools is not None:
                tools.append(tool)
        d_dict = {'query': query, 'response': response}
        if history is not None:
            d_dict['history'] = history
        if system is not None:
            d_dict['system'] = system
        if tools is not None:
            d_dict['tools'] = tools
        dataset = HfDataset.from_dict(d_dict)
        return dataset


def _default_repair_conversations(s: Union[str, Any]) -> Any:
    if isinstance(s, str):
        return ast.literal_eval(s)
    return s


class ConversationsPreprocessor:
    """Convert multi-turn conversation rows into `query` / `response` columns.

    Each row holds a list of messages under `conversations_key`; every message
    is a mapping with a role (`from_key`) and a text (`value_key`). An optional
    leading system message becomes `system`, the last user/assistant pair
    becomes `query`/`response`, and the earlier pairs become `history`.
    """

    def __init__(self,
                 user_role: str = 'user',
                 assistant_role: str = 'assistant',
                 system_role: str = 'system',
                 conversations_key: str = 'conversations',
                 from_key: str = 'from',
                 value_key: str = 'value',
                 repair_conversations: Callable[[Union[str, Dict[str, str]]],
                                                Optional[Dict[str, str]]] = _default_repair_conversations,
                 error_strategy: Literal['delete', 'raise'] = 'raise'):
        """
        Args:
            user_role: Role name identifying user messages.
            assistant_role: Role name identifying assistant messages.
            system_role: Role name identifying the optional leading system message.
            conversations_key: Column holding the list of messages.
            from_key: Key of the role inside each message.
            value_key: Key of the text inside each message.
            repair_conversations: Callable applied to the raw column value
                before validation; returning None drops the row.
            error_strategy: 'raise' turns an invalid row into a ValueError;
                any other value silently drops the row.
        """
        self.user_role = user_role
        self.assistant_role = assistant_role
        self.system_role = system_role
        self.conversations_key = conversations_key
        self.from_key = from_key
        self.value_key = value_key
        self.repair_conversations = repair_conversations
        self.error_strategy = error_strategy

    def __call__(self, dataset: HfDataset) -> HfDataset:
        """Build a new dataset with `query`, `response` and optional `system`/`history`/`tools`.

        Raises:
            ValueError: If a row fails validation and `error_strategy` is 'raise'.
        """
        query: List[str] = []
        response: List[str] = []
        system: List[Optional[str]] = []
        tools: List[List[Dict[str, Any]]] = []
        has_system = False
        history: List[History] = []
        has_history = False
        has_tools = False

        for d in tqdm(dataset):
            try:
                conversations = d[self.conversations_key]
                conversations = self.repair_conversations(conversations)
                if conversations is None:
                    continue
                lo = 0
                sys_msg = None
                h: History = []
                assert len(conversations) >= 2
                if conversations[0][self.from_key] == self.system_role:
                    has_system = True
                    lo += 1
                    sys_msg = conversations[0][self.value_key]
                assert conversations[-2][self.from_key] == self.user_role
                assert conversations[-1][self.from_key] == self.assistant_role

                # Pair up the earlier turns (everything before the final
                # user/assistant exchange) as history.
                for q, r in zip(conversations[lo:-2:2], conversations[lo + 1:-2:2]):
                    assert q[self.from_key] == self.user_role
                    assert r[self.from_key] == self.assistant_role
                    h.append([q[self.value_key], r[self.value_key]])

                # Tools are read only after the row has passed validation, so a
                # dropped row never leaves a stray entry that would misalign
                # the `tools` column with `query`. A missing or None value is
                # treated as "no tools".
                tool = d.get('tools')
                if tool is None:
                    tool = []
                if len(tool) > 0:
                    has_tools = True
                if len(h) > 0:
                    has_history = True
                query.append(conversations[-2][self.value_key])
                response.append(conversations[-1][self.value_key])
                system.append(sys_msg)
                history.append(h)
                tools.append(tool)
            except (AssertionError, SyntaxError) as e:
                if self.error_strategy == 'raise':
                    raise ValueError(f'conversations: {conversations}') from e
        kwargs = {}
        if has_system:
            kwargs['system'] = system
        kwargs.update({
            'query': query,
            'response': response,
        })
        if has_history:
            kwargs['history'] = history
        if has_tools:
            kwargs['tools'] = tools
        dataset = HfDataset.from_dict(kwargs)
        return dataset


class ComposePreprocessor:
    """Apply a list of preprocessors one after another, in list order."""

    def __init__(self, preprocessor_list: List[PreprocessFunc]) -> None:
        self.preprocessor_list = preprocessor_list

    def __call__(self, dataset: HfDataset) -> HfDataset:
        # Feed the output of each step into the next one.
        result = dataset
        for step in self.preprocessor_list:
            result = step(result)
        return result


class RenameColumnsPreprocessor:
    """Rename dataset columns according to an old-name -> new-name mapping."""

    def __init__(self, rename_mapping: Dict[str, str]) -> None:
        self.rename_mapping = rename_mapping

    def __call__(self, dataset: HfDataset) -> HfDataset:
        # Renames are applied one at a time, in the mapping's iteration order.
        renamed = dataset
        for source, target in self.rename_mapping.items():
            renamed = renamed.rename_column(source, target)
        return renamed


def preprocess_sharegpt(dataset: HfDataset) -> HfDataset:
    """Convert rows with a `conversation` column into `query` / `response` columns.

    Each `conversation` is a sequence of turns with 'human' and 'assistant'
    keys, either as an object or as the text of a Python literal. The last
    turn becomes `query`/`response`; earlier turns become `history`. String
    conversations that fail to parse with a SyntaxError are skipped. The
    `history` and `system` columns are emitted only when at least one row has
    a non-empty history / a non-None system value.
    """
    query = []
    response = []
    system: List[Optional[str]] = []
    history: List[History] = []
    for row in tqdm(dataset):
        turns = row['conversation']
        if isinstance(turns, str):
            try:
                turns = ast.literal_eval(turns)
            except SyntaxError:
                continue
        final_turn = turns[-1]
        query.append(final_turn['human'])
        response.append(final_turn['assistant'])
        history.append([[turn['human'], turn['assistant']] for turn in turns[:-1]])
        system.append(row.get('system'))
    kwargs = {'query': query, 'response': response}
    if any(len(past) > 0 for past in history):
        kwargs['history'] = history
    if any(value is not None for value in system):
        kwargs['system'] = system
    return HfDataset.from_dict(kwargs)


class SmartPreprocessor:
    """Choose a preprocessor from the dataset's column names and apply it.

    Formats are tried in the order they appear in `preprocessor_mapping`; the
    first one whose required columns are all present wins.
    """

    def __init__(self) -> None:
        self.preprocessor_mapping = {
            'swift': {'required': ['response'], 'preprocessor': SwiftPreprocessor()},
            'alpaca': {'required': ['instruction', 'output'], 'preprocessor': AlpacaPreprocessor()},
            # qwen
            'conversations': {'required': ['conversations'], 'preprocessor': ConversationsPreprocessor()},
            'chatml': {
                'required': ['messages'],
                'preprocessor': ConversationsPreprocessor(
                    conversations_key='messages', from_key='role', value_key='content'),
            },
            'sharegpt': {'required': ['conversation'], 'preprocessor': preprocess_sharegpt},
        }

    def _get_preprocessor(self, dataset: HfDataset) -> PreprocessFunc:
        """Return the first registered preprocessor whose required columns all exist.

        Raises:
            ValueError: If no registered format matches the dataset's columns.
        """
        available = set(dataset.features.keys())
        required_keys_mapping = {name: spec['required'] for name, spec in self.preprocessor_mapping.items()}
        for name, required_keys in required_keys_mapping.items():
            if set(required_keys).issubset(available):
                return self.preprocessor_mapping[name]['preprocessor']
        raise ValueError(f"""dataset.features.keys(): {dataset.features.keys()}
required_keys_mapping: {required_keys_mapping}""")

    def __call__(self, dataset: HfDataset) -> HfDataset:
        return self._get_preprocessor(dataset)(dataset)


class TextGenerationPreprocessor:
    """Wrap one column in a prompt template to form `query`; copy another column as `response`."""

    def __init__(self, prompt: str, query_key: str = 'query', response_key: str = 'response') -> None:
        # `prompt` is a str.format template that is filled via the `query` field.
        self.prompt = prompt
        self.query_key = query_key
        self.response_key = response_key

    def __call__(self, dataset: HfDataset) -> HfDataset:
        prompts = [self.prompt.format(query=row[self.query_key]) for row in tqdm(dataset)]
        return HfDataset.from_dict({'query': prompts, 'response': dataset[self.response_key]})


class ClsPreprocessor:
    """Turn a classification dataset into prompt-style `query` / `response` rows.

    The query is a fixed template naming the task, the sentence (or sentence
    pair) and the candidate categories; the response is the label name looked
    up by the row's integer `label`.
    """

    def __init__(self, labels: List[str], task_name: str, is_pair_seq: bool = False) -> None:
        self.labels = labels
        self.task_name = task_name
        self.is_pair_seq = is_pair_seq
        category = ', '.join(labels)
        # The sentence placeholders stay unformatted here; they are filled per row in __call__.
        inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}' if is_pair_seq else 'Sentence: {sentence}'
        self.prompt = f"""Task: {task_name}
{inputs}
Category: {category}
Output:"""

    def __call__(self, dataset: HfDataset) -> HfDataset:
        query = []
        response = []
        for row in tqdm(dataset):
            label = row['label']
            if label is None:  # ignore dataset error
                continue
            if self.is_pair_seq:
                fields = {'sentence1': row['sentence1'], 'sentence2': row['sentence2']}
            else:
                fields = {'sentence': row['sentence']}
            query.append(self.prompt.format(**fields))
            response.append(self.labels[int(label)])
        return HfDataset.from_dict({'query': query, 'response': response})