"python-package/vscode:/vscode.git/clone" did not exist on "76c44d786af7fba3a5a5ed073f19be909dce1f40"
data.py 25.5 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
import math
import torch
import random
import datasets
import numpy as np
from glob import glob
from string import Formatter
from typing import Optional, Tuple, Union, List, Callable, Dict, Any, Mapping
from copy import deepcopy
from dataclasses import dataclass
from collections import defaultdict
from transformers.tokenization_utils import PreTrainedTokenizer
from ..utils.util import get_max_length_in_nested_lists, pad_nested_lists, split_file_dir_name_ext, DatasetProcessFn


class RetrievalDataset:
    def get_train_process_fn(train_group_size=8, select_positive="first", select_negative="random", teacher_scores_margin=None, teacher_scores_min=None, stable_distill=False, instruction=None):
        @DatasetProcessFn()
        def _process(query:str, task:str, pos:List[str]=None, neg:List[str]=None, history:List[str]=None, teacher_scores:Optional[List[float]]=None, **kwds):
            output = {}
            keys = []
            if history is not None:
                pos = []
                neg = history

            # filter based on teacher scores
            if teacher_scores is not None:
                assert len(teacher_scores) == len(pos) + len(neg), f"Found incompatible teacher_score size ({len(teacher_scores)}) and positive size ({len(pos)}) negative size ({len(neg)})"
                if teacher_scores_min is not None:
                    max_score = max(teacher_scores)
                    if max_score < teacher_scores_min:
                        return None
                if teacher_scores_margin is not None:
                    max_score = max(teacher_scores)
                    min_score = min(teacher_scores)
                    if max_score - min_score < teacher_scores_margin:
                        return None

            pos_num = len(pos)
            if select_positive == "random":
                assert pos_num > 0, f"Select positive strategy 'random' is only available when there is a given positive!"
                pos_idx = random.choice(range(pos_num))
                pos = pos[pos_idx]
            elif teacher_scores is not None and select_positive == "teacher":
                pos_idx = max(enumerate(teacher_scores), key=lambda x: x[1])[0]
                if pos_idx < pos_num:
                    pos = pos[pos_idx]
                else:
                    # pos is selected from neg, thus we remove it from neg
                    pos = neg.pop(pos_idx - pos_num)
            elif teacher_scores is not None and select_positive == "teacher-pos":
                assert pos_num > 0, f"Select positive strategy 'teacher-pos' is only available when there are teacher_scores and positives!"
                pos_scores = teacher_scores[:pos_num]
                pos_idx = max(enumerate(pos_scores), key=lambda x: x[1])[0]
                pos = pos[pos_idx]
            else:
                # NOTE: default to select the first positive
                assert pos_num > 0, f"Select positive strategy 'first' is only available when there is a given positive!"
                pos_idx = 0
                pos = pos[0]

            if teacher_scores is not None:
                if pos_idx >= pos_num:
                    # only makes sense when select_positive==teacher
                    # remove the selected score
                    pos_score = teacher_scores.pop(pos_idx)
                else:
                    pos_score = teacher_scores[pos_idx]
                # remove teacher scores of unused positives
                neg_scores = teacher_scores[pos_num:]
                return_teacher_scores = [pos_score]

            keys.append(pos)

            if len(neg) == 0:
                return None
            elif len(neg) < train_group_size - 1:
                num = math.ceil((train_group_size - 1) / len(neg))
                neg = neg * num
                if teacher_scores is not None:
                    neg_scores = neg_scores * num

            if teacher_scores is not None and select_negative == "teacher-":
                neg_indices = [i for i, _ in sorted(enumerate(neg_scores), key=lambda x: x[1])[:train_group_size - 1]]
            elif teacher_scores is not None and select_negative == "teacher+":
                neg_indices = [i for i, _ in sorted(enumerate(neg_scores), key=lambda x: x[1], reverse=True)[:train_group_size - 1]]
            elif select_negative == "first":
                neg_indices = list(range(len(neg)))[:train_group_size - 1]
            else:
                # NOTE: default to select random negatives
                neg_indices = random.sample(range(len(neg)), train_group_size - 1)
            for neg_idx in neg_indices:
                keys.append(neg[neg_idx])
                if teacher_scores is not None:
                    return_teacher_scores.append(neg_scores[neg_idx])

            if instruction is not None:
                query = instruction["query"] + query
                keys = [instruction["key"] + key for key in keys]

            output = {
                "query": query,
                "key": keys,
                "task": task,
            }
            if teacher_scores is not None:
                output["teacher_scores"] = return_teacher_scores

            if stable_distill:
                # when using stable_distill, we must sort teacher_scores descendingly
                neg_score = output["teacher_scores"][1:]
                neg = output["key"][1:]
                pairs = sorted(list(zip(neg, neg_score)), key=lambda x: x[1], reverse=True)
                neg = [pair[0] for pair in pairs]
                neg_score = [pair[1] for pair in pairs]
                output["key"][1:] = neg
                output["teacher_scores"][1:] = neg_score

            return output
        return _process

    def prepare_train_dataset(data_file=None, cache_dir=None, config=None, train_group_size=8, select_positive="first", select_negative="random", max_sample_num=None, teacher_scores_margin=None, teacher_scores_min=None, stable_distill=False, add_instruction=False, instruction=None, use_train_config=False):
        if data_file is None:
            return None, None

        if isinstance(data_file, str):
            if "*" in data_file:
                data_file = glob(data_file)
            else:
                data_file = [data_file]

        train_datasets = []
        offset = 0
        dataset_indices_range = {}
        dataset_dup = defaultdict(int)

        for path in data_file:
            temp_dataset = datasets.load_dataset('json', data_files=path, split='train', cache_dir=cache_dir)
            task = temp_dataset[0]["task"]
            directory, _, _ = split_file_dir_name_ext(path)
            dataset_name = directory.name

            if add_instruction:
                instruction = config["instruction"][task]
            
            if use_train_config:
                train_config = config["training"][task]
                select_positive = train_config["select_positive"]
                select_negative = train_config["select_negative"]
                max_sample_num = train_config["max_sample_num"]
                teacher_scores_margin = train_config["teacher_scores_margin"]
                teacher_scores_min = train_config["teacher_scores_min"]
                stable_distill = train_config["stable_distill"]

            process_fn = RetrievalDataset.get_train_process_fn(
                train_group_size, 
                select_positive=select_positive,
                select_negative=select_negative,
                teacher_scores_margin=teacher_scores_margin,
                teacher_scores_min=teacher_scores_min,
                stable_distill=stable_distill,
                instruction=instruction
            )
            # map to filter
            temp_dataset = temp_dataset.map(process_fn, batched=True, num_proc=32, remove_columns=temp_dataset.column_names)
            # limit sample number
            if max_sample_num is not None and len(temp_dataset) > max_sample_num:
                temp_dataset = temp_dataset.train_test_split(max_sample_num, shuffle=False)["test"]
            train_datasets.append(temp_dataset)

            if dataset_name in dataset_indices_range:
                # NOTE: we allow duplicated dataset to balance the portion of different datasets
                dataset_dup[dataset_name] += 1
                dataset_indices_range[f"{dataset_name}_{dataset_dup[dataset_name]}"] = (offset, offset + len(temp_dataset))
            else:
                dataset_indices_range[dataset_name] = (offset, offset + len(temp_dataset))
            offset += len(temp_dataset)

        dataset = datasets.concatenate_datasets(train_datasets)
        return dataset, dataset_indices_range
    
    @staticmethod
    def prepare_eval_dataset(data_file=None, cache_dir=None, instruction=None, eval_method="retrieve"):
        if data_file is None:
            return None
        @DatasetProcessFn()
        def _process(query:str, query_id:Optional[int]=None, key:Optional[List[str]]=None, key_index: Optional[List[int]]=None, pos: Optional[List[Union[int, str]]]=None, neg: Optional[List[str]]=None, pos_index:Optional[List[int]]=None, neg_index: Optional[List[int]]=None, _index=None, **kwds):
            if instruction is not None:
                query = instruction["query"] + query
            
            if query_id is None:
                assert _index is not None
                query_id = _index

            output = {
                "query": query,
                "query_id": query_id,
                "task": task,
            }

            if eval_method == "rerank":
                # if there is a column named key, it must be the candidates to rerank
                if key is not None:
                    if key_index is not None:
                        output["key_index"] = key_index
                    else:
                        # NOTE: there must be key_index when reranking
                        output["key_index"] = list(range(len(key)))
                # otherwise, default 
                elif pos is not None and neg is not None:
                    key = pos + neg
                    if pos_index is not None:
                        output["key_index"] = pos_index + neg_index
                    else:
                        # NOTE: there must be key_index when reranking
                        output["key_index"] = list(range(len(key)))
                else:
                    raise ValueError(f"Expected either pos/neg or key in the file {data_file}!")

                if instruction is not None:
                    output["key"] = [instruction["key"] + k for k in key]
                else:
                    output["key"] = key
            return output

        dataset = datasets.load_dataset('json', data_files=data_file, split='train', cache_dir=cache_dir)
        if "task" in dataset:
            task = dataset[0]["task"]
        else:
            task = "nan"

        dataset = dataset.map(_process, num_proc=32, batched=True, remove_columns=dataset.column_names, with_indices=True)
        return dataset

    @staticmethod
    def prepare_corpus(data_file, key_template:str, cache_dir=None, instruction=None):
        """Concatenate desired keys by key_template"""
        if data_file is None:
            return None
        keys = Formatter().parse(key_template)
        field_names = [x[1] for x in keys if x[1] is not None]
        @DatasetProcessFn()
        def _process(**kwds):
            inputs = {name: kwds[name] for name in field_names}
            content = key_template.format(**inputs)
            if instruction is not None:
                content = instruction["key"] + content
            return {'content': content}
        dataset = datasets.load_dataset('json', data_files=data_file, split="train", cache_dir=cache_dir)
        dataset.set_transform(_process)
        return dataset


class SameDatasetTrainDataset(torch.utils.data.Dataset):
    """Dataset to yield a batch of data at one time. All samples in the same batch comes from the same task.
    
    Args:
        organize_method: 
            random:
            epoch:
            epoch-random:
            epoch-static
    """
    def __init__(self, dataset, dataset_indices_range, batch_size, seed, organize_method, process_index=0, num_processes=1):
        self.dataset = dataset
        self.batch_size = batch_size
        self.organize_method = organize_method
        self.process_index = process_index
        self.num_processes = num_processes

        self.dataset_indices_range = dataset_indices_range

        self.deterministic_generator = np.random.default_rng(seed)
        # different devices must sample different data batch
        self.nondeterministic_generator = np.random.default_rng(seed + process_index)

        # shuffle the indices
        if "random" in self.organize_method:
            self.sample_range = [np.arange(*x) for x in self.dataset_indices_range.values()]
            for x in self.sample_range:
                # NOTE: we must make sure every processes use the same shuffling order
                self.deterministic_generator.shuffle(x)
    
    def create_epoch(self):
        epoch = []
        for k, x in self.dataset_indices_range.items():
            dataset_range = np.arange(*x)
            # NOTE: we must make sure every processes use the same shuffling order
            self.deterministic_generator.shuffle(dataset_range)
            num_batches, remainer = divmod(len(dataset_range), self.batch_size * self.num_processes)
            # Truncate
            if remainer != 0:
                dataset_range = dataset_range[:num_batches * self.batch_size * self.num_processes]

            batches = dataset_range.reshape(num_batches, self.batch_size * self.num_processes).tolist()
            for i in range(len(batches)):
                batches[i] = (k, batches[i])
            epoch.extend(batches)
        # shuffle among datasets, also make sure different processes share the same shuffling results
        self.deterministic_generator.shuffle(epoch)
        self.epoch = epoch
        self.step = 0
        self.steps_per_epoch = len(epoch)

    def __getitem__(self, idx):        
        if self.organize_method == "random":
            sample_prob = [len(x) / len(self.dataset) for x in self.sample_range]

            dataset_name = self.deterministic_generator.choice(range(len(self.sample_range)), size=1, p=sample_prob)[0]
            sample_range = self.sample_range[dataset_name]

            batch_indices = self.nondeterministic_generator.choice(sample_range, size=self.batch_size, replace=False)
            batch_data = self.dataset[batch_indices.tolist()]

        elif self.organize_method == "epoch":
            if not hasattr(self, "epoch") or self.step > self.steps_per_epoch - 1:
                self.create_epoch()

            dataset_name, batch_indices = self.epoch[self.step]
            batch_indices = batch_indices[self.process_index * self.batch_size: (self.process_index + 1) * self.batch_size]
            batch_data = self.dataset[batch_indices]
            self.step += 1
        
        elif self.organize_method == "epoch-static":
            if not hasattr(self, "epoch"):
                # the data within each batch is static once created
                self.create_epoch()
            
            if self.step > self.steps_per_epoch - 1:
                self.deterministic_generator.shuffle(self.epoch)
                self.step = 0

            dataset_name, batch_indices = self.epoch[self.step]
            batch_indices = batch_indices[self.process_index * self.batch_size: (self.process_index + 1) * self.batch_size]
            batch_data = self.dataset[batch_indices]
            self.step += 1
        
        elif self.organize_method == "epoch-random":
            sample_scope = [len(x) for x in self.sample_range]
            sample_prob = [x / sum(sample_scope) for x in sample_scope]

            dataset_name = self.deterministic_generator.choice(range(len(self.sample_range)), size=1, p=sample_prob)[0]
            sample_range = self.sample_range[dataset_name]

            # sequential sample (the indices are already shuffled)
            batch_indices = sample_range[self.process_index * self.batch_size: (self.process_index + 1) * self.batch_size]
            batch_data = self.dataset[batch_indices.tolist()]
            # update indices
            remaining_indices = sample_range[self.num_processes * self.batch_size:]
            if len(remaining_indices) < self.batch_size * self.num_processes:
                remaining_indices = np.array([])
            self.sample_range[dataset_name] = remaining_indices
            # restore all indices if they are all sampled
            if all(len(x) == 0 for x in self.sample_range):
                self.sample_range = [np.arange(*x) for x in self.dataset_indices_range.values()]
                for x in self.sample_range:
                    self.deterministic_generator.shuffle(x)
        else:
            raise NotImplementedError(f"Organize method {self.organize_method} is not implemented for SameTaskTrainDataset!")

        return batch_data
    
    def __len__(self):
        return len(self.dataset) // self.batch_size


@dataclass
class RetrievalDataCollator:
    """
    """
    tokenizer: PreTrainedTokenizer = None
    query_max_length: int = 256
    key_max_length: int = 256
    inbatch_same_dataset: bool = False
    cross: bool = False

    def __call__(self, batch_elem):
        first_elem = batch_elem[0]
        return_batch = {}
        
        for k, v in first_elem.items():
            if self.inbatch_same_dataset:
                # here the data have already been grouped
                batch_value = batch_elem[0][k]
            else:
                batch_value = [elem[k] for elem in batch_elem]
            
            # collate training/evaluating
            if k == "query":
                query = batch_value
                # NOTE: we do not need the individual query and key when requiring cross data
                if self.cross:
                    continue
                batch_value = self.tokenizer(
                    batch_value,
                    padding=True,
                    truncation=True,
                    max_length=self.query_max_length,
                    return_tensors="pt",
                )
            elif k == "key":
                # in case the keys are of different sizes for different queries when reranking
                max_length = get_max_length_in_nested_lists(batch_value)
                batch_value, key_mask = pad_nested_lists(batch_value, max_length, "", "right")
                batch_value = sum(batch_value, [])
                key = batch_value
                # key_mask assigns 1 to valid keys and 0 to padded keys
                return_batch["key_mask"] = torch.tensor(key_mask)
                # NOTE: we do not need the individual query and key when requiring cross data
                if self.cross:
                    continue
                batch_value = self.tokenizer(
                    batch_value,
                    padding=True,
                    truncation=True,
                    max_length=self.key_max_length,
                    return_tensors="pt",
                )

            elif k == "key_index":
                max_length = get_max_length_in_nested_lists(batch_value)
                batch_value, _ = pad_nested_lists(batch_value, max_length, -1, "right")
                batch_value = torch.tensor(batch_value)

            elif k == "content":
                # collate corpus
                batch_value = self.tokenizer(
                    batch_value,
                    padding=True,
                    truncation=True,
                    max_length=self.key_max_length,
                    return_tensors="pt",
                )

            elif k == "task":
                assert all(v == batch_value[0] for v in batch_value), f"Make sure all samples are of the same task in a batch!"
                batch_value = batch_value[0]

            elif all(v is None for v in batch_value):
                # in case that some data have teacher_scores but others do not
                batch_value = None

            else:
                batch_value = torch.tensor(batch_value)

            return_batch[k] = batch_value                

        if self.cross:
            query_num = len(query)
            key_num = len(key)
            assert key_num % query_num == 0
            group_size = key_num // query_num
            new_query = []
            for i in range(key_num):
                new_query.append(query[i // group_size])

            return_batch["cross"] = self.tokenizer(
                new_query, key, 
                padding=True, 
                truncation=True,
                max_length=self.key_max_length + self.query_max_length,
                return_tensors="pt"
            )
            return_batch["batch_size"] = len(query)

        return return_batch


TASK_CONFIG = {
    "llm-embedder": {
        "instruction": {
            "qa": {
                "query": "Represent this query for retrieving relevant documents: ",
                "key": "Represent this document for retrieval: ",
            },
            "convsearch": {
                "query": "Encode this query and context for searching relevant passages: ",
                "key": "Encode this passage for retrieval: ",
            },
            "chat": {
                "query": "Embed this dialogue to find useful historical dialogues: ",
                "key": "Embed this historical dialogue for retrieval: ",
            },
            "lrlm": {
                "query": "Embed this text chunk for finding useful historical chunks: ",
                "key": "Embed this historical text chunk for retrieval: ",
            },
            "icl": {
                "query": "Convert this example into vector to look for useful examples: ",
                "key": "Convert this example into vector for retrieval: ",
            },
            "tool": {
                "query": "Transform this user request for fetching helpful tool descriptions: ",
                "key": "Transform this tool description for retrieval: "
            },
        },

        "training": {
            "qa": {
                "select_positive": "first",
                "select_negative": "random",
                "max_sample_num": None,
                "teacher_scores_margin": None,
                "teacher_scores_min": None, 
                "contrastive_weight": 0,
                "stable_distill": True,
            },
            "convsearch": {
                "select_positive": "first",
                "select_negative": "random",
                "max_sample_num": None,
                "teacher_scores_margin": None,
                "teacher_scores_min": None,
                "distill_weight": 0,
                "stable_distill": False,
            },
            "chat": {
                "select_positive": "teacher",
                "select_negative": "random",
                "max_sample_num": None,
                "teacher_scores_margin": None,
                "teacher_scores_min": None,
                "distill_weight": 1.0,
                "contrastive_weight": 0,
                "teacher_temperature": 0.1,
                "stable_distill": False,
            },
            "lrlm": {
                "select_positive": "teacher",
                "select_negative": "random",
                "max_sample_num": 10000,
                "teacher_scores_margin": 0.1,
                "teacher_scores_min": None,
                "distill_weight": 1.0,
                "contrastive_weight": 0,
                "teacher_temperature": 0.1,            
                "stable_distill": False,
            },
            "icl": {
                "select_positive": "random",
                "select_negative": "random",
                "max_sample_num": None,
                "teacher_scores_margin": None,
                "teacher_scores_min": None,
                "contrastive_weight": 0,
                "stable_distill": True,
            },
            "tool": {
                "select_positive": "first",
                "select_negative": "random",
                "max_sample_num": None,
                "teacher_scores_margin": None,
                "teacher_scores_min": None,
                "distill_weight": 0,
                "stable_distill": False,
            },
        }
    },

    "bge": {
        "instruction": defaultdict(lambda: {"query": "Represent this sentence for searching relevant passages: ", "key": ""})
    },
    
    "e5": {
        "instruction": defaultdict(lambda: {"query": "query: ", "key": "passage: "})
    },
    
    "instructor": {
        "instruction": {
            "qa": {
                "query": "Represent the query for retrieving supporting documents: ",
                "key": "Represent the document for retrieval: ",
            },
            "convsearch": {
                "query": "Represent the query and context for retrieving supporting passages: ",
                "key": "Represent the passage for retrieval: ",
            },
            "chat": {
                "query": "Represent the dialogue for retrieving useful historical dialogues: ",
                "key": "Represent the historical dialogue for retrieval: ",
            },
            "lrlm": {
                "query": "Represent the text chunk for retrieving useful historical chunks: ",
                "key": "Represent the historical text chunk for retrieval: ",
            },
            "icl": {
                "query": "Represent the example for retrieving duplicate examples: ",
                "key": "Represent the example for retrieval: ",
            },
            "tool": {
                "query": "Represent the user request for retrieving duplicate examples: ",
                "key": "Represent the tool description for retrieval: "
            },
        },
    }
}