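"""Summarizer configurations for the OpenCompass NeedleBench benchmark.

Covers single-needle retrieval (S-RT), multi-needle retrieval (M-RT) and
multi-needle reasoning (M-RS) at context sizes from 4k to 1000k, the 8k
parallel-needle batch-size ablation, and the Ancestral Trace Challenge
(ATC) summarizers.
"""
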
from opencompass.summarizers.needlebench import NeedleBenchSummarizer


def create_m_rs_names_list(context_lengths, depths, needle_counts,
                           languages, dataset_size):
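    """Build the Multi-Needle-Reasoning (M-RS) dataset-name groups.

    Returns a dict mapping group keys such as '2-Needle-EN-4K' to the
    per-(length, depth) dataset abbreviations, plus aggregate lists for
    all M-RS subsets and the EN/ZH splits.
    """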
    names_dict = {}
    multi_needle_list = []
    multi_needle_en_list = []
    multi_needle_zh_list = []

    for needle_count in needle_counts:
        for language in languages:
            key = f"{needle_count}-Needle-{language.upper()}-{dataset_size.upper()}"
            names_list = [
                f"Length{length}Depth{int(depth)}_{needle_count}needle_{language}_{dataset_size}"
                for length in context_lengths
                for depth in depths
            ]
            names_dict[key] = names_list

            multi_needle_list.extend(names_list)
            if language == 'en':
                multi_needle_en_list.extend(names_list)
            elif language == 'zh':
                multi_needle_zh_list.extend(names_list)
    names_dict['Multi-Needle-Reasoning(M-RS)'] = multi_needle_list
    names_dict['Multi-Needle-Reasoning-EN'] = multi_needle_en_list
    names_dict['Multi-Needle-Reasoning-ZH'] = multi_needle_zh_list

    return names_dict

def create_summarizer(context_lengths, depths, dataset_size,
                      sparse_depths=None):
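    """Build the full NeedleBench summarizer config for one dataset size.

    Collects the single-needle retrieval (S-RT), multi-needle retrieval
    (M-RT) and multi-needle reasoning (M-RS) groups and combines them into
    NeedleBench-Overall-Score with weights 0.4/0.3/0.3. If ``sparse_depths``
    is given it replaces ``depths``; the long-context variants use this to
    thin out the depth grid.
    """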
    needle_counts = ["2", "3", "4", "5"]
    languages = ["en", "zh"]
    if sparse_depths:
        depths = sparse_depths
    names_dict = {}
    multi_reasoning_names = create_m_rs_names_list(
        context_lengths, depths, needle_counts, languages, dataset_size)

    names_dict.update(multi_reasoning_names)

    single_needle_list = []
    single_needle_en_list = []
    single_needle_zh_list = []

    for language in languages:
        names_list = [
            f"Length{length}Depth{int(depth)}_origin_{language}_{dataset_size}"
            for length in context_lengths
            for depth in depths
        ]
        single_needle_list.extend(names_list)
        if language == 'en':
            single_needle_en_list.extend(names_list)
        elif language == 'zh':
            single_needle_zh_list.extend(names_list)
    names_dict['Single-Needle-Retrieval(S-RT)'] = single_needle_list
    names_dict['Single-Needle-Retrieval-EN'] = single_needle_en_list
    names_dict['Single-Needle-Retrieval-ZH'] = single_needle_zh_list

    parallel_list = []
    parallel_en_list = []
    parallel_zh_list = []

    for language in languages:
        names_list = [
            f"Length{length}_parallel_{language}_{dataset_size}"
            for length in context_lengths
        ]
        parallel_list.extend(names_list)
        if language == 'en':
            parallel_en_list.extend(names_list)
        elif language == 'zh':
            parallel_zh_list.extend(names_list)
    names_dict['Multi-Needle-Retrieval(M-RT)'] = parallel_list
    names_dict['Multi-Needle-Retrieval-EN'] = parallel_en_list
    names_dict['Multi-Needle-Retrieval-ZH'] = parallel_zh_list

    summary_groups = [
        {'name': key, 'subsets': value} for key, value in names_dict.items()
    ]

    summary_groups.append({
        'name': 'NeedleBench-Overall-Score',
        'subsets': [['Single-Needle-Retrieval(S-RT)', 'naive_average'],
                    ['Multi-Needle-Reasoning(M-RS)', 'naive_average'],
                    ['Multi-Needle-Retrieval(M-RT)', 'average_score']],
        'weights': {'Single-Needle-Retrieval(S-RT)': 0.4,
                    'Multi-Needle-Reasoning(M-RS)': 0.3,
                    'Multi-Needle-Retrieval(M-RT)': 0.3}})
    summarizer_config = {
        'type': NeedleBenchSummarizer,
        'summary_groups': summary_groups,
        'dataset_abbrs': [
            'NeedleBench-Overall-Score',
            f'--------- NeedleBench-{dataset_size.upper()}-Single-Needle-Retrieval ---------',
            'Single-Needle-Retrieval(S-RT)',
            'Single-Needle-Retrieval-EN',
            'Single-Needle-Retrieval-ZH',
            f'--------- NeedleBench-{dataset_size.upper()}-Multi-Needle-Retrieval ---------',
            'Multi-Needle-Retrieval(M-RT)',
            'Multi-Needle-Retrieval-EN',
            'Multi-Needle-Retrieval-ZH',
            f'--------- NeedleBench-{dataset_size.upper()}-Multi-Needle-Reasoning ---------',
            'Multi-Needle-Reasoning(M-RS)',
            'Multi-Needle-Reasoning-EN',
            'Multi-Needle-Reasoning-ZH',
            f'2-Needle-EN-{dataset_size.upper()}',
            f'2-Needle-ZH-{dataset_size.upper()}',
            f'3-Needle-EN-{dataset_size.upper()}',
            f'3-Needle-ZH-{dataset_size.upper()}',
            f'4-Needle-EN-{dataset_size.upper()}',
            f'4-Needle-ZH-{dataset_size.upper()}',
            f'5-Needle-EN-{dataset_size.upper()}',
            f'5-Needle-ZH-{dataset_size.upper()}',
            ]
        }
    return summarizer_config


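# Needle burial depths, expressed as percentages of the context length;
# the sparse list thins the grid for the long-context variants.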
depths = [0, 5, 10, 15, 21, 26, 31, 36, 42, 47, 52, 57, 63, 68, 73, 78, 84, 89, 94, 100]
depths_list_sparse = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]

context_lengths_4k = list(range(1000, 5000, 1000))
needlebench_4k_summarizer = create_summarizer(context_lengths_4k, depths, "4k")
context_lengths_8k = list(range(5000, 9000, 1000))
needlebench_8k_summarizer = create_summarizer(context_lengths_8k, depths, "8k")
context_lengths_32k = [9000, 13000, 17000, 21000, 25000, 29000, 31000, 32000]
needlebench_32k_summarizer = create_summarizer(context_lengths_32k, depths_list_sparse, "32k")
context_lengths_128k = [16000, 32000, 48000, 64000, 80000, 96000, 112000, 128000]
needlebench_128k_summarizer = create_summarizer(context_lengths_128k, depths_list_sparse, "128k")
context_lengths_200k = [16000, 48000, 80000, 112000, 128000, 144000, 176000, 200000]
needlebench_200k_summarizer = create_summarizer(context_lengths_200k, depths_list_sparse, "200k")
context_lengths_1000k = [20000, 160000, 300000, 440000, 580000, 720000, 860000, 1000000]
needlebench_1000k_summarizer = create_summarizer(context_lengths_1000k, depths_list_sparse, "1000k")
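
# Minimal usage sketch (the import path is illustrative and depends on where
# this file sits in your config tree): select one of the summarizers above
# in an OpenCompass config, e.g.
#
#   from mmengine.config import read_base
#   with read_base():
#       from .summarizers.needlebench import needlebench_4k_summarizer
#   summarizer = needlebench_4k_summarizer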


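# Batch-size ablation for the 8k parallel-needle setting: one abbreviation
# list per (language, batch size) combination.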
_needlebench_8k_parallel_en_batch1 = []
_needlebench_8k_parallel_en_batch5 = []
_needlebench_8k_parallel_en_batch10 = []
_needlebench_8k_parallel_en_batch15 = []
_needlebench_8k_parallel_en_batch20 = []
_needlebench_8k_parallel_zh_batch1 = []
_needlebench_8k_parallel_zh_batch5 = []
_needlebench_8k_parallel_zh_batch10 = []
_needlebench_8k_parallel_zh_batch15 = []
_needlebench_8k_parallel_zh_batch20 = []
for original_context_length in context_lengths_8k:
    _needlebench_8k_parallel_en_batch1.append(f'Length{original_context_length}_parallel_en_8k_batch1')
    _needlebench_8k_parallel_en_batch5.append(f'Length{original_context_length}_parallel_en_8k_batch5')
    _needlebench_8k_parallel_en_batch10.append(f'Length{original_context_length}_parallel_en_8k_batch10')
    _needlebench_8k_parallel_en_batch15.append(f'Length{original_context_length}_parallel_en_8k_batch15')
    _needlebench_8k_parallel_en_batch20.append(f'Length{original_context_length}_parallel_en_8k_batch20')
    _needlebench_8k_parallel_zh_batch1.append(f'Length{original_context_length}_parallel_zh_8k_batch1')
    _needlebench_8k_parallel_zh_batch5.append(f'Length{original_context_length}_parallel_zh_8k_batch5')
    _needlebench_8k_parallel_zh_batch10.append(f'Length{original_context_length}_parallel_zh_8k_batch10')
    _needlebench_8k_parallel_zh_batch15.append(f'Length{original_context_length}_parallel_zh_8k_batch15')
    _needlebench_8k_parallel_zh_batch20.append(f'Length{original_context_length}_parallel_zh_8k_batch20')


_needlebench_8k_parallel_batch1 = _needlebench_8k_parallel_en_batch1 + _needlebench_8k_parallel_zh_batch1
_needlebench_8k_parallel_batch5 = _needlebench_8k_parallel_en_batch5 + _needlebench_8k_parallel_zh_batch5
_needlebench_8k_parallel_batch10 = _needlebench_8k_parallel_en_batch10 + _needlebench_8k_parallel_zh_batch10
_needlebench_8k_parallel_batch15 = _needlebench_8k_parallel_en_batch15 + _needlebench_8k_parallel_zh_batch15
_needlebench_8k_parallel_batch20 = _needlebench_8k_parallel_en_batch20 + _needlebench_8k_parallel_zh_batch20

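# Average each batch-size variant over all 8k context lengths using the
# per-dataset 'average_score' metric.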
needlebench_summary_groups = [
    {'name': 'parallel_version_batch1', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_batch1]},
    {'name': 'parallel_version_zh_batch1', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_zh_batch1]},
    {'name': 'parallel_version_en_batch1', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_en_batch1]},
    {'name': 'parallel_version_batch5', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_batch5]},
    {'name': 'parallel_version_zh_batch5', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_zh_batch5]},
    {'name': 'parallel_version_en_batch5', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_en_batch5]},
    {'name': 'parallel_version_batch10', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_batch10]},
    {'name': 'parallel_version_zh_batch10', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_zh_batch10]},
    {'name': 'parallel_version_en_batch10', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_en_batch10]},
    {'name': 'parallel_version_batch15', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_batch15]},
    {'name': 'parallel_version_zh_batch15', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_zh_batch15]},
    {'name': 'parallel_version_en_batch15', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_en_batch15]},
    {'name': 'parallel_version_batch20', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_batch20]},
    {'name': 'parallel_version_zh_batch20', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_zh_batch20]},
    {'name': 'parallel_version_en_batch20', 'subsets': [[_dataset, "average_score"] for _dataset in _needlebench_8k_parallel_en_batch20]},
]

needlebench_8k_batch_overall_summarizer = dict(
    dataset_abbrs=[
        '--------- NeedleBench-8k Parallel-Needles ---------',  # category
        'parallel_version_batch1',
        'parallel_version_batch5',
        'parallel_version_batch10',
        'parallel_version_batch15',
        'parallel_version_batch20',
        'parallel_version_zh_batch1',
        'parallel_version_en_batch1',
        'parallel_version_zh_batch5',
        'parallel_version_en_batch5',
        'parallel_version_zh_batch10',
        'parallel_version_en_batch10',
        'parallel_version_zh_batch15',
        'parallel_version_en_batch15',
        'parallel_version_zh_batch20',
        'parallel_version_en_batch20',
    ],
    summary_groups=needlebench_summary_groups,
)

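# Rebinding needlebench_summary_groups is safe here: the overall summarizer
# above has already captured the previous list object. This variant reads
# the 'Depth0' metric instead of 'average_score'.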
needlebench_summary_groups = [
    {'name': 'parallel_version_batch1', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_batch1]},
    {'name': 'parallel_version_zh_batch1', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_zh_batch1]},
    {'name': 'parallel_version_en_batch1', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_en_batch1]},
    {'name': 'parallel_version_batch5', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_batch5]},
    {'name': 'parallel_version_zh_batch5', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_zh_batch5]},
    {'name': 'parallel_version_en_batch5', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_en_batch5]},
    {'name': 'parallel_version_batch10', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_batch10]},
    {'name': 'parallel_version_zh_batch10', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_zh_batch10]},
    {'name': 'parallel_version_en_batch10', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_en_batch10]},
    {'name': 'parallel_version_batch15', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_batch15]},
    {'name': 'parallel_version_zh_batch15', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_zh_batch15]},
    {'name': 'parallel_version_en_batch15', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_en_batch15]},
    {'name': 'parallel_version_batch20', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_batch20]},
    {'name': 'parallel_version_zh_batch20', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_zh_batch20]},
    {'name': 'parallel_version_en_batch20', 'subsets': [[_dataset, "Depth0"] for _dataset in _needlebench_8k_parallel_en_batch20]},
]

needlebench_8k_batch_depth0_summarizer = dict(
    dataset_abbrs=[
        '--------- NeedleBench-8k Parallel-Needles ---------',  # category
        'parallel_version_batch1',
        'parallel_version_batch5',
        'parallel_version_batch10',
        'parallel_version_batch15',
        'parallel_version_batch20',
        'parallel_version_zh_batch1',
        'parallel_version_en_batch1',
        'parallel_version_zh_batch5',
        'parallel_version_en_batch5',
        'parallel_version_zh_batch10',
        'parallel_version_en_batch10',
        'parallel_version_zh_batch15',
        'parallel_version_en_batch15',
        'parallel_version_zh_batch20',
        'parallel_version_en_batch20',
    ],
    summary_groups=needlebench_summary_groups,
)


def gen_atc_summarizer(needle_num_list):
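    """Build the summarizer for the Ancestral Trace Challenge (ATC).

    ``needle_num_list`` gives the needle counts to aggregate. Each
    Direct/Reasoning x ZH/EN category is reduced to a needle-count-weighted
    average, and the four categories are averaged into ATC-CE-Overall.
    """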
    categories = [
        'ZH-Direct-CE', 'EN-Direct-CE',
        'ZH-Reasoning-CE', 'EN-Reasoning-CE'
    ]
    needlebench_atc_summary_groups = []

    # Build one summary group for each category
    for category in categories:
        # Use the perf_4 metric for CircularEval-related scores, otherwise acc_1
        metric = 'perf_4' if 'CE' in category else 'acc_1'
        # Dataset names in subsets do not carry the CircularEval markers, so strip them
        cleaned_category = category.replace('-CE', '').replace('-Direct', '')
        needlebench_atc_summary_groups.append({
            'name': category,
            'subsets': [
                [f'NeedleBenchATCDataset-{num_needles}Needle-{cleaned_category}', metric]
                for num_needles in needle_num_list
            ],
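            # Weight each subset by its needle count so that harder
            # settings (more needles) contribute proportionally more.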
            'weights': {f'NeedleBenchATCDataset-{num_needles}Needle-{cleaned_category}': num_needles for num_needles in needle_num_list},
        })

    needlebench_atc_summary_groups.append({
        'name': 'ATC-CE-Overall',
        'subsets': [
            [f'{category}', 'weighted_average'] for category in categories
        ],
    })
    atc_dataset_abbrs = []
    atc_dataset_abbrs.append(['ATC-CE-Overall', 'naive_average'])

    for category in categories:
        weighted_average_score_entry = [f'{category}', 'weighted_average']
        atc_dataset_abbrs.append(weighted_average_score_entry)

    needlebench_atc_summarizer = dict(
        dataset_abbrs=[
            *atc_dataset_abbrs,
            '######## Needlebench-ATC Accuracy ########',  # category
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH', 'acc_1'] for num_needles in needle_num_list],
            '------------------------------------------',
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN', 'acc_1'] for num_needles in needle_num_list],
            '------------------------------------------',
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH-Reasoning', 'acc_1'] for num_needles in needle_num_list],
            '------------------------------------------',
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN-Reasoning', 'acc_1'] for num_needles in needle_num_list],
            '------------------------------------------',
            '######## Needlebench-ATC CircularEval ########',  # category
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH', 'perf_4'] for num_needles in needle_num_list],
            '------------------------------------------',
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN', 'perf_4'] for num_needles in needle_num_list],
            '------------------------------------------',
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH-Reasoning', 'perf_4'] for num_needles in needle_num_list],
            '------------------------------------------',
            *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN-Reasoning', 'perf_4'] for num_needles in needle_num_list],
            '------------------------------------------',
        ],
        summary_groups=needlebench_atc_summary_groups
    )
    return needlebench_atc_summarizer


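# Preset ATC summarizers covering 2-19, 2-49 and 2-79 needles respectively.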
atc_summarizer_20 = gen_atc_summarizer(list(range(2, 20)))
atc_summarizer_50 = gen_atc_summarizer(list(range(2, 50)))
atc_summarizer_80 = gen_atc_summarizer(list(range(2, 80)))