judge_task.py 1.96 KB
Newer Older
Baber's avatar
Baber committed
1
import json
Baber's avatar
Baber committed
2
from typing import Any, Dict, Optional
Baber's avatar
Baber committed
3
4
5
6
7
8
9

import datasets

from lm_eval.api.task import ConfigurableTask


class JudgeTask(ConfigurableTask):
    """Task whose test split is augmented with previously generated responses.

    Responses are read from a JSON-lines file (``output_path``), one record
    per line with at least ``"doc_id"`` and ``"resps"`` keys. They are
    attached to the test split as a ``"resp"`` column. The train split gets
    a ``"resp"`` column copied from its ``"answer"`` column.
    """

    def __init__(
        self,
        data_dir=None,
        cache_dir=None,
        download_mode=None,
        config: Optional[dict] = None,
        output_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Remember the responses file and initialise the base task.

        Args:
            data_dir: forwarded to ``ConfigurableTask.__init__``.
            cache_dir: forwarded to ``ConfigurableTask.__init__``.
            download_mode: forwarded to ``ConfigurableTask.__init__``.
            config: task configuration, forwarded to the base class.
            output_path: path to the JSON-lines file of earlier responses;
                it is read in :meth:`download`.
            **kwargs: accepted for call compatibility and ignored.
        """
        # Set before super().__init__ because the base initialiser presumably
        # calls download(), which reads this attribute -- TODO confirm.
        self.output_path = output_path
        super().__init__(
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
            config=config,
        )

    @staticmethod
    def _load_responses(output_path: Optional[str]) -> list:
        """Return one response string per ``doc_id``, in ascending ``doc_id`` order.

        Each line of the file is a JSON object; the response is taken from
        ``record["resps"][0][0]``. A ``doc_id`` may appear several times
        (e.g. once per filter); the first occurrence in file order is kept.

        Raises:
            ValueError: if ``output_path`` is ``None``.
        """
        if output_path is None:
            raise ValueError(
                "JudgeTask requires `output_path` pointing to a JSON-lines "
                "file of responses"
            )
        first_by_doc: Dict[Any, str] = {}
        with open(output_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                record = json.loads(line)
                # TODO: records are not distinguished by filter name; the
                # first one seen for a doc wins.
                first_by_doc.setdefault(record["doc_id"], record["resps"][0][0])
        return [first_by_doc[doc_id] for doc_id in sorted(first_by_doc)]

    def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
        """Load the dataset and add a ``"resp"`` column to its splits.

        Args:
            dataset_kwargs: extra keyword arguments for
                ``datasets.load_dataset``; ``None`` means none.

        Raises:
            ValueError: if ``output_path`` is missing, or the number of
                distinct responses differs from the number of test rows.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            **(dataset_kwargs if dataset_kwargs is not None else {}),
        )

        responses = self._load_responses(self.output_path)
        test_split = self.dataset["test"]
        # Responses are matched to rows positionally; this assumes doc_id i
        # corresponds to row i of the test split -- TODO confirm.
        if len(responses) != len(test_split):
            raise ValueError(
                f"Found responses for {len(responses)} docs in "
                f"{self.output_path!r}, but the test split has "
                f"{len(test_split)} rows"
            )
        self.dataset["test"] = test_split.add_column("resp", responses)

        if "train" in self.dataset:
            # Reference answers stand in for responses on the train split.
            self.dataset["train"] = self.dataset["train"].add_column(
                "resp", self.dataset["train"]["answer"]
            )
        print("resp columns added")