# run_dense.py
import logging
import torch

import datasets
from dataclasses import asdict
from transformers import (
    HfArgumentParser,
)
from src.retrieval import DenseRetriever
from src.retrieval.metrics import RetrievalMetric
from src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
from src.retrieval.args import RetrievalArgs, RetrievalTrainingArgs
from src.retrieval.data import RetrievalDataset, RetrievalDataCollator, SameDatasetTrainDataset, TASK_CONFIG
from src.utils.util import FileLogger, makedirs

logger = logging.getLogger(__name__)


def main():
    """Train and/or evaluate a dense retriever.

    Parses :class:`RetrievalArgs` and :class:`RetrievalTrainingArgs` from the
    command line, builds the :class:`DenseRetriever`, prepares the train/eval
    datasets and the corpus, then either trains (when train data is given) or
    evaluates (when only eval data is given).
    """
    parser = HfArgumentParser((RetrievalArgs, RetrievalTrainingArgs))
    model_args, training_args = parser.parse_args_into_dataclasses()
    # Narrowing annotations for type-checker/IDE support only — no runtime effect.
    model_args: RetrievalArgs
    training_args: RetrievalTrainingArgs

    # Per-version task configuration; maps task names to their instructions.
    config = TASK_CONFIG[model_args.version]
    instruction = config["instruction"]

    model = DenseRetriever(
        **asdict(model_args), 
        cache_dir=model_args.model_cache_dir, 
        cos_temperature=training_args.cos_temperature,
        contrastive_weight=training_args.contrastive_weight,
        distill_weight=training_args.distill_weight,
        teacher_temperature=training_args.teacher_temperature,
        student_temperature=training_args.student_temperature,
        negative_cross_device=training_args.negative_cross_device,
        stable_distill=training_args.stable_distill,
    )

    if training_args.use_train_config:
        model.train_config = config["training"]

    tokenizer = model.tokenizer

    # Dataset preparation may write to the cache; run it on the main process
    # first so other ranks reuse the cached result instead of racing.
    with training_args.main_process_first():
        train_dataset, task_indices_range = RetrievalDataset.prepare_train_dataset(
            data_file=model_args.train_data, 
            cache_dir=model_args.dataset_cache_dir,
            add_instruction=model_args.add_instruction,
            train_group_size=training_args.train_group_size,
            config=config,
            use_train_config=training_args.use_train_config,
            select_positive=training_args.select_positive,
            select_negative=training_args.select_negative,
            max_sample_num=training_args.max_sample_num,
            teacher_scores_margin=training_args.teacher_scores_margin,
            teacher_scores_min=training_args.teacher_scores_min,
            stable_distill=training_args.stable_distill,
        )

        # We must read the evaluation task from the raw data BEFORE applying
        # any instruction, because the instruction depends on the task name.
        if model_args.eval_data is not None and model_args.add_instruction:
            raw_eval_dataset = datasets.load_dataset('json', data_files=model_args.eval_data, split='train', cache_dir=model_args.dataset_cache_dir)
            eval_task = raw_eval_dataset[0]["task"]
        else:
            eval_task = None

        # Shared by the eval dataset and the corpus; None means "no instruction".
        eval_instruction = instruction[eval_task] if eval_task is not None else None

        eval_dataset = RetrievalDataset.prepare_eval_dataset(
            data_file=model_args.eval_data, 
            cache_dir=model_args.dataset_cache_dir,
            instruction=eval_instruction,
            eval_method=training_args.eval_method,
        )
        corpus = RetrievalDataset.prepare_corpus(
            data_file=model_args.corpus,
            key_template=model_args.key_template,
            cache_dir=model_args.dataset_cache_dir,
            instruction=eval_instruction,
        )

    if training_args.process_index == 0:
        # NOTE: this corpus is for computing metrics, where no instruction is given;
        # only rank 0 needs it because metrics are computed there.
        no_instruction_corpus = RetrievalDataset.prepare_corpus(
            data_file=model_args.corpus,
            key_template=model_args.key_template,
            cache_dir=model_args.dataset_cache_dir,
        )
    else:
        no_instruction_corpus = None

    if training_args.inbatch_same_dataset is not None:
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently skip this required configuration check.
        if training_args.dataloader_num_workers != 0:
            raise ValueError("Make sure dataloader num_workers is 0 when using inbatch_same_dataset!")
        train_dataset = SameDatasetTrainDataset(
            train_dataset, 
            task_indices_range, 
            batch_size=training_args.per_device_train_batch_size, 
            seed=training_args.seed, 
            organize_method=training_args.inbatch_same_dataset, 
            num_processes=training_args.world_size,
            process_index=training_args.process_index,
        )
        # The wrapper dataset yields pre-batched samples, so the dataloader's
        # own batch size collapses to 1.
        training_args.per_device_train_batch_size = 1

    if training_args.early_exit_steps is not None:
        callbacks = [EarlyExitCallBack(training_args.early_exit_steps)]
    else:
        callbacks = []

    trainer = RetrievalTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=callbacks,
        corpus=corpus,
        model_args=model_args,
        data_collator=RetrievalDataCollator(
            tokenizer=tokenizer,
            query_max_length=model_args.query_max_length,
            key_max_length=model_args.key_max_length,
            inbatch_same_dataset=training_args.inbatch_same_dataset
        ),
        compute_metrics=RetrievalMetric.get_metric_fn(
            model_args.metrics,
            # for collecting labels
            eval_data=model_args.eval_data,
            cutoffs=model_args.cutoffs,
            # for collecting positives and collating retrieval results
            save_name=model_args.save_name,
            output_dir=training_args.output_dir,
            save_to_output=model_args.save_to_output,
            # for restoring text from indices when collating results
            corpus=no_instruction_corpus,
            max_neg_num=model_args.max_neg_num,
            # for nq metrics
            cache_dir=model_args.dataset_cache_dir,
            # for collate_neg
            filter_answers=model_args.filter_answers
        ),
        file_logger=FileLogger(makedirs(training_args.log_path))
    )
    # Tie accelerators so the model shares the trainer's distributed context.
    model.accelerator = trainer.accelerator

    # Training takes precedence; evaluation runs only when no train data is given.
    if train_dataset is not None:
        trainer.train()
        return

    if eval_dataset is not None:
        trainer.evaluate()

if __name__ == "__main__":
    main()