preprocess_ragtruth.py 3.22 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Literal


@dataclass
class RagTruthSample:
    prompt: str
    answer: str
    labels: list[dict]
    split: Literal["train", "dev", "test"]
    task_type: str

    def to_json(self) -> dict:
        return {
            "prompt": self.prompt,
            "answer": self.answer,
            "labels": self.labels,
            "split": self.split,
            "task_type": self.task_type,
        }

    @classmethod
    def from_json(cls, json_dict: dict) -> "RagTruthSample":
        return cls(
            prompt=json_dict["prompt"],
            answer=json_dict["answer"],
            labels=json_dict["labels"],
            split=json_dict["split"],
            task_type=json_dict["task_type"],
        )


@dataclass
class RagTruthData:
    samples: list[RagTruthSample]

    def to_json(self) -> list[dict]:
        return [sample.to_json() for sample in self.samples]

    @classmethod
    def from_json(cls, json_dict: list[dict]) -> "RagTruthData":
        return cls(
            samples=[RagTruthSample.from_json(sample) for sample in json_dict],
        )


def load_data(input_dir: Path) -> tuple[list[dict], list[dict]]:
    """Load the RAG truth data.

    :param input_dir: Path to the input directory.
    """
    responses = [
        json.loads(line) for line in (input_dir / "response.jsonl").read_text().splitlines()
    ]
    sources = [
        json.loads(line) for line in (input_dir / "source_info.jsonl").read_text().splitlines()
    ]

    return responses, sources


def create_sample(response: dict, source: dict) -> RagTruthSample:
    """Create a sample from the RAG truth data.

    :param response: The response from the RAG truth data.
    :param source: The source from the RAG truth data.
    """
    prompt = source["prompt"]

    answer = response["response"]
    split = response["split"]
    task_type = source["task_type"]
    labels = []

    for label in response["labels"]:
        start_char = label["start"]
        end_char = label["end"]
        labels.append(
            {
                "start": start_char,
                "end": end_char,
                "label": label["label_type"],
            }
        )

    return RagTruthSample(prompt, answer, labels, split, task_type)


def main(input_dir: Path, output_dir: Path):
    """Preprocess the RAG truth data.

    :param input_dir: Path to the input directory.
    :param output_dir: Path to the output directory.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    responses, sources = load_data(input_dir)
    sources_by_id = {source["source_id"]: source for source in sources}

    rag_truth_data = RagTruthData(samples=[])

    for response in responses:
        sample = create_sample(response, sources_by_id[response["source_id"]])
        rag_truth_data.samples.append(sample)

    (output_dir / "ragtruth_data.json").write_text(json.dumps(rag_truth_data.to_json(), indent=4))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args()

    main(args.input_dir, args.output_dir)