# origin.py
1
import json
2
import os
3
import random
4
5
6
import re
from pathlib import Path

7
import tiktoken
8
9
10
11
12
13
14
from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS


15
def get_random_line_by_language(counter, file_path, language):
    """Pick one needle record for ``language`` from a JSONL file.

    Every non-blank line of the file is parsed as a JSON object and kept when
    its ``'language'`` field equals ``language``. One of the kept records is
    then chosen with the module-level ``random`` generator after seeding it
    with ``counter``, so the same ``counter`` selects the same record for an
    unchanged file.

    Args:
        counter: Seed passed to ``random.seed`` before choosing.
        file_path: Path of the JSONL needle file (opened as UTF-8).
        language: Value the ``'language'`` field must equal.

    Returns:
        A dict with the keys ``'needle'``, ``'retrieval_question'`` and
        ``'keyword'`` (the latter taken from the record's ``'arg2'`` field),
        or ``None`` when no record matches ``language``.
    """
    candidates = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for raw_line in file:
            raw_line = raw_line.strip()
            if not raw_line:
                # Tolerate blank lines (e.g. a trailing newline at EOF)
                # instead of failing inside json.loads('').
                continue
            # Parse once and reuse the record for both filter and result.
            record = json.loads(raw_line)
            if record['language'] == language:
                candidates.append(record)

    if not candidates:
        return None

    # The global generator is (re)seeded on purpose: the selection must be
    # reproducible for a given counter.
    random.seed(counter)
    chosen = random.choice(candidates)
    return {
        'needle': chosen['needle'],
        'retrieval_question': chosen['retrieval_question'],
        'keyword': chosen['arg2'],
    }


34
@LOAD_DATASET.register_module()
class NeedleBenchOriginDataset(BaseDataset):
    """Dataset that hides one "needle" sentence inside a long token context.

    Each sample is a prompt made of haystack text with a needle inserted at a
    given depth plus a retrieval question, and an answer string of the form
    ``needle + '*' + keyword``.
    """

    @staticmethod
    def load(
        path: str,
        length: int,
        depth: int,
        tokenizer_model: str,
        file_list: list[str],
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        needle_file_name: str,
        position: str = 'End',
    ):
        """Build the prompt/answer dataset.

        Args:
            path: Directory holding the haystack ``*.jsonl`` files and the
                needle file.
            length: Token budget of a sample before ``length_buffer`` is
                subtracted.
            depth: Insertion depth of the needle, as a percentage (0-100) of
                the haystack token count.
            tokenizer_model: Model name handed to
                ``tiktoken.encoding_for_model``.
            file_list: Names of the ``*.jsonl`` files in ``path`` to use;
                other files are skipped.
            num_repeats_per_file: Number of samples generated per file; the
                repeat index is used as the random seed.
            length_buffer: Tokens reserved out of ``length``.
            guide: When true, a "think about the most relevant content"
                instruction is spliced into the retrieval question.
            language: ``'Chinese'`` or ``'English'``.
            needle_file_name: Name of the JSONL needle file inside ``path``.
            position: ``'End'`` puts the question after the document,
                ``'Start'`` puts it before.

        Returns:
            A ``Dataset`` with the columns ``'prompt'`` and ``'answer'``.

        Raises:
            ValueError: If ``language`` or ``position`` is unsupported, or if
                the needle file has no record for ``language``.
        """
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        def _generate_context(tokens_context, depth_percent, needle):
            """Insert the needle tokens at ``depth_percent`` and decode."""
            tokens_needle = _get_tokens_from_context(needle)
            insertion_point = int(len(tokens_context) * (depth_percent / 100))
            tokens_context = (tokens_context[:insertion_point] +
                              tokens_needle + tokens_context[insertion_point:])
            new_context = _decode_tokens(tokens_context)
            return new_context

        def _get_tokens_from_context(context):
            """Encode text into tokenizer tokens."""
            return tokenizer.encode(context)

        def _decode_tokens(tokens):
            """Decode tokenizer tokens back into text."""
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            """Splice the guidance sentence in front of the format request."""
            if language == 'Chinese':
                parts = retrieval_question.split('请按照')
                guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
                                            '最相关的内容是什么。请按照' + parts[1])
                return guide_retrieval_question
            elif language == 'English':
                parts = retrieval_question.split('Please answer in the format')
                guide_retrieval_question = (
                    parts[0] + 'Before answering, please consider'
                    ' what in the document is most relevant to this question.'
                    ' Please answer in the format' + parts[1])
                return guide_retrieval_question
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            """Assemble the final prompt string for one sample."""
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)

            if language == 'Chinese':
                if position == 'End':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n'
                              f'用户现在给你的文档是{context}\n\n'
                              f'现在请问:{retrieval_question}')
                elif position == 'Start':
                    # No comma between the two f-strings: they must be
                    # concatenated into one str, not packed into a tuple.
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n'
                              f'现在请问:{retrieval_question}'
                              f'用户现在给你的文档是{context}\n\n')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            elif language == 'English':
                if position == 'End':
                    prompt = ('You are an intelligent AI assistant skilled in '
                              'answering user questions.\n'
                              'Please keep your answers concise and clear. Do '
                              'not talk about irrelevant topics or repeat '
                              'your answers.\nThe document '
                              f'given to you by the user is {context}\n\n'
                              f'Now, the question is: {retrieval_question}')
                elif position == 'Start':
                    prompt = ('You are an intelligent AI assistant skilled in '
                              'answering user questions.\n'
                              'Please keep your answers concise and clear. Do '
                              'not talk about irrelevant topics or repeat '
                              'your answers.\n'
                              f'Now, the question is: {retrieval_question}'
                              'The document given to you by the user'
                              f' is {context}\n\n')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            else:
                raise ValueError(f"Language '{language}' is not supported.")

            return prompt

        # The needle file location does not depend on the haystack file or
        # the repeat index, so compute it once.
        needle_file_path = os.path.join(path, needle_file_name)

        files = Path(path).glob('*.jsonl')
        for file in files:
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                lines_bak = [json.loads(line.strip()) for line in f]
            lines = lines_bak.copy()
            for counter in range(num_repeats_per_file):
                # The list is shuffled in place, so every repeat reshuffles
                # the order left by the previous one after reseeding.
                random.seed(counter)
                random.shuffle(lines)

                random_needle = get_random_line_by_language(
                    counter, needle_file_path, language)
                if random_needle is None:
                    raise ValueError(
                        f"No needle for language '{language}' found in "
                        f'{needle_file_path}.')
                needle = '\n' + random_needle['needle'] + '\n'
                retrieval_question = random_needle['retrieval_question']
                keyword = random_needle['keyword']

                # Haystack budget: total length minus the buffer and the
                # tokens the needle itself will occupy (never negative).
                context_length = length - length_buffer
                target_length_per_record = context_length - len(
                    _get_tokens_from_context(needle))
                target_length_per_record = max(target_length_per_record, 0)

                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)

                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depth,
                    needle)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                # '*' separates the needle from the keyword in the answer.
                data['answer'].append(needle + '*' + keyword)

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset


179
class NeedleBenchOriginEvaluator(BaseEvaluator):
    """Score needle-retrieval predictions against ``needle*keyword`` golds.

    A prediction that contains the keyword scores 100. Otherwise it scores
    20% of a similarity percentage derived from the Levenshtein distance
    between the whitespace-stripped prediction and needle.
    """

    def __init__(self, use_trim=False):
        """Store the trimming option.

        Args:
            use_trim (bool): When true, predictions are cut down with
                ``_trim_prediction`` before the edit distance is computed.
        """
        super().__init__()
        self.use_trim = use_trim

    @staticmethod
    def _trim_prediction(prediction, reference):
        """Trims the prediction string based on the length of the reference
        string.

        The prediction is first cut to 120% of the reference length. If the
        last character of the reference occurs at or after the 80% mark of
        that cut, the prediction is cut again right after its first such
        occurrence.

        Args:
            prediction (str): The prediction string.
            reference (str): The reference string.

        Returns:
            str: The trimmed prediction string.
        """
        l08 = int(0.8 * len(reference))
        l12 = int(1.2 * len(reference))
        trimmed_prediction = prediction[:l12]

        # The length check comes first so that an empty reference never
        # reaches the reference[-1] lookup.
        if len(trimmed_prediction) <= l08:
            return trimmed_prediction

        end_index = trimmed_prediction.find(reference[-1], l08)
        if end_index != -1:
            trimmed_prediction = trimmed_prediction[:end_index + 1]

        return trimmed_prediction

    def levenshtein_distance(self, s1, s2):
        """Return the Levenshtein edit distance between ``s1`` and ``s2``.

        Only two rows of the dynamic-programming table are kept at a time.
        """
        # Make s2 the shorter string so the rows are as short as possible.
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, gold):
        """Score every prediction and return the average plus details.

        Args:
            predictions (list[str]): Model outputs.
            gold (list[str]): References of the form ``needle*keyword``.
                The keyword is taken to be the text after the last ``'*'``.

        Returns:
            dict: ``{'score': average, 'details': [...]}``, or an
            ``{'error': ...}`` dict when the two lists differ in length.

        Raises:
            ValueError: If a reference contains no ``'*'`` separator.
        """
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}

        whitespace = re.compile(r'\s+')
        total_score = 0
        details = []
        for prediction, reference in zip(predictions, gold):
            # Split on the LAST '*' so a needle that itself contains '*'
            # is not truncated and the keyword is still found.
            needle, separator, keyword = reference.rpartition('*')
            if not separator:
                raise ValueError(
                    "Reference must have the form 'needle*keyword', got "
                    f'{reference!r}.')
            raw_prediction = prediction
            prediction = whitespace.sub('', prediction)
            reference = whitespace.sub('', needle)

            if self.use_trim:
                prediction = NeedleBenchOriginEvaluator._trim_prediction(
                    prediction, reference)

            edit_distance = self.levenshtein_distance(prediction, reference)
            max_len = max(len(prediction), len(reference))
            score = 100 * (1 -
                           edit_distance / max_len) if max_len != 0 else 100

            # The keyword is looked up in the untouched prediction; a hit
            # overrides the similarity score, a miss scales it down to 20%.
            if keyword in raw_prediction:
                print(f'{keyword} is in {prediction}')
                score = 100
            else:
                print(f'{keyword} is not in {prediction}')
                score = 0.2 * score

            detail = {
                'pred': prediction,
                'answer': reference,
                'edit_distance': edit_distance,
                'score': score
            }
            total_score += score
            details.append(detail)

        average_score = total_score / len(predictions) if predictions else 0
        result = {'score': average_score, 'details': details}
        return result


270
271
@TEXT_POSTPROCESSORS.register_module('needlebench')
def needlebench_postprocess(text: str) -> str:
    """Return ``text`` unchanged.

    Registered as the ``'needlebench'`` text postprocessor; it applies no
    transformation to its input.
    """
    return text


275
276
@TEXT_POSTPROCESSORS.register_module('needlebench_dataset')
def needlebench_dataset_postprocess(text: str) -> str:
    """Return ``text`` unchanged.

    Registered as the ``'needlebench_dataset'`` text postprocessor; it
    applies no transformation to its input.
    """
    return text