"include/vscode:/vscode.git/clone" did not exist on "0a7a4080a89b1582d673e7f23e3a37c4349f1432"
run_ranker.py 5.94 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
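"""Train or evaluate a cross-encoder reranker for retrieval.

Parses RankerArgs and RetrievalTrainingArgs from the command line via
HfArgumentParser, prepares the train/eval datasets and corpus, and runs
RetrievalTrainer. Example invocation (a sketch; flag names mirror the
dataclass fields used below and should be checked against
src/retrieval/args.py):

    python run_ranker.py \
        --ranker <model_name_or_path> \
        --train_data <train.jsonl> \
        --eval_data <eval.jsonl> \
        --corpus <corpus.jsonl> \
        --output_dir <output_dir>
"""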
import logging

import datasets
from dataclasses import asdict
from transformers import (
    HfArgumentParser,
)
from src.retrieval import CrossEncoder
from src.retrieval.metrics import RetrievalMetric
from src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
from src.retrieval.args import RankerArgs, RetrievalTrainingArgs
from src.retrieval.data import RetrievalDataset, RetrievalDataCollator, SameDatasetTrainDataset, TASK_CONFIG
from src.utils.util import FileLogger, makedirs

logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((RankerArgs, RetrievalTrainingArgs))
    model_args, training_args = parser.parse_args_into_dataclasses()
    model_args: RankerArgs
    training_args: RetrievalTrainingArgs

    # this script only reranks, so force the rerank evaluation method
    training_args.eval_method = "rerank"

    config = TASK_CONFIG[model_args.version]
    instruction = config["instruction"]

    if model_args.ranker_method == "cross-encoder":
        model = CrossEncoder(
            ranker=model_args.ranker,
            # NOTE: fp16 weights cannot be trained directly; uncomment the line below to force fp32 when training
            # dtype="fp32" if model_args.train_data is not None else model_args.dtype,
            dtype=model_args.dtype,
            cache_dir=model_args.model_cache_dir, 
        )
        cross = True
    else:
        raise NotImplementedError(f"Ranker method {model_args.ranker_method} not implemented!")

    if training_args.use_train_config:
        model.train_config = config["training"]

    tokenizer = model.tokenizer

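    # build datasets on the main process first so other ranks reuse the cached result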
    with training_args.main_process_first():
        train_dataset, task_indices_range = RetrievalDataset.prepare_train_dataset(
            data_file=model_args.train_data, 
            cache_dir=model_args.dataset_cache_dir,
            add_instruction=model_args.add_instruction,
            train_group_size=training_args.train_group_size,
            config=config,
            use_train_config=training_args.use_train_config,
            select_positive=training_args.select_positive,
            select_negative=training_args.select_negative,
            max_sample_num=training_args.max_sample_num,
            teacher_scores_margin=training_args.teacher_scores_margin,
            teacher_scores_min=training_args.teacher_scores_min,
        )

        # determine the evaluation task before attaching instructions, since the instruction depends on the task
        if model_args.eval_data is not None and model_args.add_instruction:
            raw_eval_dataset = datasets.load_dataset('json', data_files=model_args.eval_data, split='train', cache_dir=model_args.dataset_cache_dir)
            eval_task = raw_eval_dataset[0]["task"]
        else:
            eval_task = None

        eval_dataset = RetrievalDataset.prepare_eval_dataset(
            data_file=model_args.eval_data, 
            cache_dir=model_args.dataset_cache_dir,
            instruction=instruction[eval_task] if eval_task is not None else None,
            eval_method=training_args.eval_method,
        )
        corpus = RetrievalDataset.prepare_corpus(
            data_file=model_args.corpus,
            key_template=model_args.key_template,
            cache_dir=model_args.dataset_cache_dir,
            instruction=instruction[eval_task] if eval_task is not None else None 
        )
    
    if training_args.process_index == 0:
        # NOTE: this corpus is for computing metrics, where no instruction is given
        no_instruction_corpus = RetrievalDataset.prepare_corpus(
            data_file=model_args.corpus,
            key_template=model_args.key_template,
            cache_dir=model_args.dataset_cache_dir,
        )
    else:
        no_instruction_corpus = None

    if training_args.inbatch_same_dataset is not None:
        assert training_args.dataloader_num_workers == 0, "Make sure dataloader_num_workers is 0 when using inbatch_same_dataset!"
        train_dataset = SameDatasetTrainDataset(
            train_dataset, 
            task_indices_range, 
            batch_size=training_args.per_device_train_batch_size, 
            seed=training_args.seed, 
            organize_method=training_args.inbatch_same_dataset, 
            num_processes=training_args.world_size,
            process_index=training_args.process_index,
        )
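        # SameDatasetTrainDataset yields one pre-assembled batch per item, so the dataloader batch size drops to 1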
        training_args.per_device_train_batch_size = 1
    
    if training_args.early_exit_steps is not None:
        callbacks = [EarlyExitCallBack(training_args.early_exit_steps)]
    else:
        callbacks = []

    trainer = RetrievalTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=callbacks,
        corpus=corpus,
        model_args=model_args,
        data_collator=RetrievalDataCollator(
            tokenizer=tokenizer,
            query_max_length=model_args.query_max_length,
            key_max_length=model_args.key_max_length,
            inbatch_same_dataset=training_args.inbatch_same_dataset,
            cross=cross
        ),
        compute_metrics=RetrievalMetric.get_metric_fn(
            model_args.metrics,
            # for collecting labels
            eval_data=model_args.eval_data,
            cutoffs=model_args.cutoffs,
            # for collecting positives and collating retrieval results
            save_name=model_args.save_name,
            output_dir=training_args.output_dir,
            save_to_output=model_args.save_to_output,
            # for restoring text from indices when collating results
            corpus=no_instruction_corpus,
            max_neg_num=model_args.max_neg_num,
            # for nq metrics
            cache_dir=model_args.dataset_cache_dir,
            # for collate_neg
            filter_answers=model_args.filter_answers
        ),
        file_logger=FileLogger(makedirs(training_args.log_path)),
    )
    # share the trainer's accelerator with the model so both use the same distributed state
    model.accelerator = trainer.accelerator

    # train if training data was provided, otherwise fall through to evaluation
    if train_dataset is not None:
        trainer.train()
        return

    if eval_dataset is not None:
        trainer.evaluate()


if __name__ == "__main__":
    main()