"include/LightGBM/vscode:/vscode.git/clone" did not exist on "b6d4ad8361379a902091e3a584491c39fa89ef2d"
eval_mmlu.py 12.8 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
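"""Evaluate a language model on MMLU, optionally augmenting questions with retrieved passages.

Pipeline:
  1. Optionally run retrieval (``eval_retrieval.main``) to attach candidate passages
     ("keys") to every question.
  2. Expand each question into four sequences, one per answer option, prepending the
     retrieved knowledge and any few-shot exemplars.
  3. Score each sequence by its negative log-likelihood and predict the lowest-NLL option.
  4. Report accuracy per MMLU category (STEM, Social Sciences, Humanities, others) and overall.

A sketch of a typical invocation (flag names mirror the ``MMLUArgs`` fields and the
inherited ``LMArgs``/``RetrievalArgs`` fields; the package prefix and model path are
placeholders, not this repository's confirmed layout):

    python -m <package>.eval_mmlu \
        --model_name_or_path <your-model> \
        --retrieval_method dense \
        --few_shot 5
"""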
import os
import copy
import json
import logging
import datasets
from typing import List
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict
from collections import defaultdict

from src.lm import (
    LM, 
    LMArgs
)
from src.retrieval import (
    RetrievalArgs, 
    RetrievalMetric,
)
from src.utils.util import makedirs, remove_eos, DefaultDataCollator, DatasetProcessFn, FileLogger
from .eval_retrieval import main as retrieval_main

logger = logging.getLogger(__name__)
import transformers
transformers.logging.set_verbosity_error()

SUBJECT_2_CATEGORY={"abstract_algebra": "STEM", "anatomy": "others", "astronomy": "STEM", "business_ethics": "others", "clinical_knowledge": "others", "college_biology": "STEM", "college_chemistry": "STEM", "college_computer_science": "STEM", "college_mathematics": "STEM", "college_medicine": "others", "college_physics": "STEM", "computer_security": "STEM", "conceptual_physics": "STEM", "econometrics": "Social Sciences", "electrical_engineering": "STEM", "elementary_mathematics": "STEM", "formal_logic": "Humanities", "global_facts": "others", "high_school_biology": "STEM", "high_school_chemistry": "STEM", "high_school_computer_science": "STEM", "high_school_european_history": "Humanities", "high_school_geography": "Social Sciences", "high_school_government_and_politics": "Social Sciences", "high_school_macroeconomics": "Social Sciences", "high_school_mathematics": "STEM", "high_school_microeconomics": "Social Sciences", "high_school_physics": "STEM", "high_school_psychology": "Social Sciences", "high_school_statistics": "STEM", "high_school_us_history": "Humanities", "high_school_world_history": "Humanities", "human_aging": "others", "human_sexuality": "Social Sciences", "international_law": "Humanities", "jurisprudence": "Humanities", "logical_fallacies": "Humanities", "machine_learning": "STEM", "management": "others", "marketing": "others", "medical_genetics": "others", "miscellaneous": "others", "moral_disputes": "Humanities", "moral_scenarios": "Humanities", "nutrition": "others", "philosophy": "Humanities", "prehistory": "Humanities", "professional_accounting": "others", "professional_law": "Humanities", "professional_medicine": "others", "professional_psychology": "Social Sciences", "public_relations": "Social Sciences", "security_studies": "Social Sciences", "sociology": "Social Sciences", "us_foreign_policy": "Social Sciences", "virology": "others", "world_religions": "Humanities"}


@dataclass
class MMLUArgs(LMArgs, RetrievalArgs):
    output_dir: str = field(
        default="data/results/mmlu",
    )
    eval_data: str = field(
        default="llm-embedder:qa/mmlu/test.json",
        metadata={'help': 'Path to the test file.'}
    )
    lm_batch_size: int = field(
        default=2,
        metadata={'help': 'Evaluation batch size.'},
    )

    few_shot: int = field(
        default=0,
        metadata={'help': 'How many few-shot training exemplars to prepend to each question.'},
    )
    train_data: str = field(
        default="llm-embedder:qa/mmlu/dev.json",
        metadata={'help': 'Path to the file containing training examples.'}
    )

    corpus: str = field(
        default="llm-embedder:qa/msmarco/corpus.json",
        metadata={'help': 'Corpus path for retrieval.'}
    )
    key_template: str = field(
        default="{title} {text}",
        metadata={'help': 'How to concatenate columns in the corpus to form one key?'}
    )
    key_max_length: int = field(
        default=128,
        metadata={'help': 'How many tokens at maximum in a key.'}
    )
    hits: int = field(
        default=10,
        metadata={'help': 'How many hits per query?'},
    )
    key_num: int = field(
        default=3,
        metadata={'help': 'How many retrieved docs to include in the prompt.'},
    )
    metrics: List[str] = field(
        default_factory=lambda: ["collate_key"],
    )
    save_to_output: bool = field(
        default=True,
        metadata={'help': 'Save the result/key/negative to output_dir? If not true, they will be saved next to the eval_data.'}
    )

    log_path: str = field(
        default="data/results/mmlu/mmlu.log",
        metadata={'help': 'Path to the file for logging.'}
    )
    

def process_mmlu(tokenizer, context_max_length=2048, key_num=3, few_shot=0, train_data=None, cache_dir=None, is_encoder_decoder=False, add_llama_inst=False):
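    """Build the dataset-mapping function used with `datasets.Dataset.map`.

    For every MMLU record it yields four tokenized sequences, one per answer option,
    with labels masked so that only the answer tokens are scored (decoder-only models)
    or used as decoder targets (encoder-decoder models).
    """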
    tokenizer.truncation_side = 'right'
    left_truncation_tokenizer = copy.deepcopy(tokenizer)
    left_truncation_tokenizer.truncation_side = 'left'

    test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]

    has_bos = has_eos = False
    if test[0] == 1:
        has_bos = True
    if test[-1] == 1:
        has_eos = True
    
    if few_shot > 0:
        assert train_data is not None
        train_data = datasets.load_dataset("json", data_files=train_data, cache_dir=cache_dir, split="train")
        train_df = train_data.to_pandas()
        # group by subject and keep the first `few_shot` exemplars per subject
        train_df = {k: v[:few_shot] for k, v in train_df.groupby("subject")}
        
    options = ['A', 'B', 'C', 'D']
    
    def _prepare_sample(query, choices, answer):
        """
        <Question>
        A. <Choices 1>
        B. <Choices 2>
        C. <Choices 3>
        D. <Choices 4>
        Answer: <Answer>
        """
        # the answer may be an int (or numpy int64) index; map it to its option letter
        if not isinstance(answer, str):
            answer = options[answer]

        sample = f"{query}\n{chr(10).join([f'{option}. {choice}' for option, choice in zip(options, choices)])}\nAnswer: {answer}"
        return sample
    
    def _prepare_knowledge(key, max_length=None):
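        """Join up to `key_num` retrieved passages into a "Knowledge:" block, truncated
        to `max_length` tokens when given; return an empty string if there is no key."""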
        if key is not None:
            key = key[:key_num]
            key = "\n".join(key)
            key = f"Knowledge:\n{key}"
            if max_length is not None:
                # truncate key if necessary
                key = tokenizer.decode(tokenizer.encode(key, add_special_tokens=False, truncation=True, max_length=max_length))
        else:
            key = ""
        return key

    @DatasetProcessFn(augment=True)
    def _process(query, choices, query_id, subject, answer, key=None, **kwds):
        """Yield key and query with a prompt template"""
        output = defaultdict(list)
        query = query.strip()

        head = f"The following are multiple choice questions (with answers) about {' '.join(subject.split('_'))}.\n\n"

        if few_shot > 0:
            train_samples = ""
            for i in range(few_shot):
                if i >= len(train_df[subject]):
                    break
                train_sample = train_df[subject].iloc[i][['query', 'choices', 'answer']]
                train_sample = _prepare_sample(**train_sample) + "\n\n"
                train_samples += train_sample
        else:
            train_samples = ""

        # Budget the retrieved knowledge so that the instruction head, few-shot
        # exemplars and the question itself still fit within context_max_length
        # (minus the BOS/EOS tokens the tokenizer may add).
        knowledge_max_length = context_max_length - len(tokenizer.encode(head + train_samples + _prepare_sample(query, choices, 'A'))) - int(has_bos) - int(has_eos)
        if knowledge_max_length < 0:
            knowledge = ""
        else:
            knowledge = _prepare_knowledge(key, knowledge_max_length)

        for option in options:
            left = knowledge
            right = head + train_samples + _prepare_sample(query, choices, option)
            # \n\n to split knowledge and prompts
            if len(left):
                right = "\n\n" + right

            # TODO: add llama instruction
            # if add_llama_inst:
            #     left = "[INST]" + left
            #     right = right + "[/INST]"

            inputs = left_truncation_tokenizer(left + right, truncation=True, max_length=context_max_length, return_token_type_ids=False)

            if has_eos and not is_encoder_decoder:
                inputs = remove_eos(inputs, tokenizer.eos_token_id)

            # number of tokens the option letter adds after "Answer:"
            option_seq = tokenizer.encode("Answer: " + option, add_special_tokens=False)
            option_length = len(option_seq) - len(tokenizer.encode("Answer:", add_special_tokens=False))

            if is_encoder_decoder:
                # strip the answer tokens from the encoder input and use them as decoder labels
                labels = inputs["input_ids"].copy()[-option_length:]
                for k, v in inputs.items():
                    inputs[k] = v[:-option_length]
                inputs["labels"] = labels

            else:
                # mask padded positions, then mask everything except the answer tokens
                labels = inputs["input_ids"].copy()
                labels = [x if inputs["attention_mask"][i] == 1 else -100 for i, x in enumerate(labels)]
                labels[:-option_length] = [-100] * (len(labels) - option_length)
                inputs["labels"] = labels

            inputs["query_id"] = query_id
            for k, v in inputs.items():
                output[k].append(v)
        return output
    return _process


def evaluate_mmlu(eval_data, save_path, **kwds):
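    """Return a `compute_metric` callable: it takes `(query_ids, nlls)`, writes the
    per-question predictions to `save_path`, and reports accuracy per MMLU category."""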
    def compute_metric(eval_preds):
        makedirs(save_path)

        tasks = defaultdict(list)
        results = defaultdict(list)
        samples = {}
        
        with open(eval_data) as f:
            for line in f:
                sample = json.loads(line.strip())
                samples[sample["query_id"]] = sample
        
        # the NLLs for each query must arrive in option order: A, B, C, D
        for query_id, nll in zip(*eval_preds):
            # store log likelihood
            results[query_id].append(-nll)
        
        with open(makedirs(save_path), "w") as f:
            for k, v in results.items():
                # predicted option = index of the highest log-likelihood
                output = max(enumerate(v), key=lambda x: x[1])[0]
                sample = samples[k]
                sample["output"] = output
                tasks[sample["subject"]].append((output, sample["answer"]))
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")

        metrics = defaultdict(list)
        for task_name, task_eval_preds in tasks.items():
            accuracy = 0
            for pred, label in task_eval_preds:
                accuracy += int(pred == label)
            accuracy /= len(task_eval_preds)

            category = SUBJECT_2_CATEGORY[task_name]
            metrics[f"{category}"].append(accuracy)
            metrics["all"].append(accuracy)
        
        for k, v in metrics.items():
            metrics[k] = sum(v) / len(v)
        
        metrics = {
            "STEM": metrics["STEM"],
            "Social Sciences": metrics["Social Sciences"],
            "Humanities": metrics["Humanities"],
            "Others": metrics["others"],
            "All": metrics["all"],
        }

        return dict(metrics)
    return compute_metric


def main():
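    """Parse arguments, optionally run retrieval, preprocess MMLU, score options with the LM, and log metrics."""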
    parser = HfArgumentParser([MMLUArgs])
    args, = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(cpu=args.cpu)

    # modify the output_dir for retrieval
    if args.retrieval_method == "dense":
        output_dir = os.path.join(args.output_dir, args.query_encoder.strip(os.sep).replace(os.sep, "--"))
    else:
        output_dir = os.path.join(args.output_dir, args.retrieval_method)
    args.output_dir = output_dir

    if args.retrieval_method != "no":
        retrieval_main(args=args, accelerator=accelerator, log=False)
        eval_data = RetrievalMetric._get_save_path(args.eval_data, args.output_dir, field="key", save_name=args.save_name)
    else:
        eval_data = args.eval_data

    lm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator
    )
    
    tokenizer = lm.tokenizer

    with accelerator.main_process_first():
        logging.info(f"Loading data from {eval_data}...")
        dataset = datasets.load_dataset("json", data_files=eval_data, split="train", cache_dir=args.dataset_cache_dir)
        dataset = dataset.map(process_mmlu(
            tokenizer, 
            context_max_length=args.context_max_length, 
            key_num=args.key_num,
            few_shot=args.few_shot,
            train_data=args.train_data,
            cache_dir=args.dataset_cache_dir,
            is_encoder_decoder=lm.model.config.is_encoder_decoder,
            add_llama_inst=args.add_llama_inst
        ), remove_columns=dataset.column_names, batched=True, num_proc=32)

    data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
    dataloader = DataLoader(
        dataset, 
        batch_size=args.lm_batch_size, 
        collate_fn=data_collator,
        pin_memory=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # score every (question, option) sequence; a lower NLL means the option is preferred
    results = lm.compute_nlls(dataloader)

    if accelerator.process_index == 0:
        file_logger = FileLogger(makedirs(args.log_path))    
        result_path = os.path.join(args.output_dir, args.model_name_or_path.strip(os.sep).replace(os.sep, "--") + ".json")
        metrics = evaluate_mmlu(eval_data, result_path)(results)
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()