Commit 7d346000 authored by gaotongxiao

initial commit
# Overview
# Installation
1. Prepare PyTorch following the instructions at [PyTorch](https://pytorch.org/).
Note that OpenCompass requires `pytorch>=1.13`:
```bash
conda create --name opencompass python=3.8 -y
conda activate opencompass
conda install pytorch torchvision -c pytorch
```
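A quick way to check that the installed version meets this requirement:
```bash
# verify that the installed PyTorch satisfies pytorch>=1.13
python -c "import torch; print(torch.__version__)"
```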
2. Install OpenCompass:
```bash
git clone https://github.com/opencompass/opencompass
cd opencompass
pip install -r requirements/runtime.txt
pip install -e .
```
3. Install humaneval (optional)
Perform this step only if you want to evaluate on the humaneval dataset.
```bash
git clone https://github.com/openai/human-eval.git
cd human-eval
pip install -r requirements.txt
pip install -e .
```
Remember to remove the comments at lines 48-57 of the source code and uncomment [line 58](https://github.com/openai/human-eval/blob/312c5e5532f0e0470bf47f77a6243e02a61da530/human_eval/execution.py#L58).
# Quick Start
# Task Execution and Monitoring
## Launching an Evaluation Task
The entry point for evaluation tasks is `run.py`. Usage:
```shell
run.py [-p PARTITION] [-q QUOTATYPE] [--debug] [-m MODE] [-r [REUSE]] [-w WORKDIR] [-l LARK] config
```
The arguments are explained below; example invocations follow the list:
- -p specifies the slurm partition;
- -q specifies the slurm quotatype (defaults to auto); choices are reserved, auto, spot;
- --debug runs inference and evaluation tasks in single-process mode with output echoed in real time, which is convenient for debugging;
- -m sets the run mode, defaulting to all. Set it to infer to run inference only and obtain model outputs; if model outputs already exist in {WORKDIR}, set it to eval to run evaluation only and obtain evaluation results; if per-item evaluation results already exist in results, set it to viz to run visualization only; all runs both inference and evaluation.
- -r reuses existing inference results. If followed by a timestamp, the results under that timestamp in the work directory are reused; otherwise the latest results in the specified work directory are reused.
- -w specifies the work directory, defaulting to ./outputs/default
- -l enables status reporting via a Lark (Feishu) bot.
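Typical invocations look like the following (`configs/eval_demo.py` is the demo config used later in this document; PARTITION is a placeholder):
```bash
# full pipeline (infer + eval + viz) on a slurm partition
python run.py configs/eval_demo.py -p PARTITION

# inference only, with a custom work directory
python run.py configs/eval_demo.py -p PARTITION -m infer -w outputs/demo

# reuse the latest inference results in the work directory and evaluate only
python run.py configs/eval_demo.py -p PARTITION -m eval -r
```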
Taking run mode `-m all` as an example, the overall flow is as follows:
1. Read the configuration file and parse out the model, dataset, evaluator, and other configuration information.
2. An evaluation task consists of three stages: inference (infer), evaluation (eval), and visualization (viz). Inference and evaluation are partitioned into sub-tasks by the Partitioner, which are then executed in parallel by the Runner. A single inference or evaluation task is abstracted as an OpenICLInferTask or OpenICLEvalTask.
3. After both stages finish, the visualization stage reads the evaluation results in results and generates a visualization report.
## Task Monitoring: Lark Bot
Users can monitor task status in real time by configuring a Lark (Feishu) bot. For instructions on setting up a Lark bot, please [refer here](https://open.feishu.cn/document/ukTMukTMukTM/ucTM5YjL3ETO24yNxkjN?lang=zh-CN#7a28964d)
Configuration:
1. Open the `configs/lark.py` file and add the following line:
```python
lark_bot_url = 'YOUR_WEBHOOK_URL'
```
Typically, the webhook URL looks like https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx .
2. Inherit this file in your full evaluation configuration:
```python
from mmengine.config import read_base
with read_base():
    from .lark import lark_bot_url
```
3. To avoid the bot spamming with frequent messages, runtime status is not reported automatically by default. When needed, status reporting can be enabled with `-l` or `--lark`:
```bash
python run.py configs/eval_demo.py -p {PARTITION} -l
```
## Summarizer Introduction
The Summarizer is mainly used for visualizing evaluation results.
## Run Results
All run results are placed under the `outputs/default/` directory by default, with the following structure:
```
outputs/default/
├── 20200220_120000
├── ...
├── 20230220_183030
│   ├── configs
│   ├── logs
│   │   ├── eval
│   │   └── infer
│   ├── predictions
│   │   └── MODEL1
│   └── results
│       └── MODEL1
```
Each timestamped directory contains the following:
- a configs folder, which stores the configuration file of each run that used this timestamp as its output directory;
- a logs folder, which stores the output logs of the inference and evaluation stages, with per-model subfolders inside each;
- a predictions folder, which stores the inference results in JSON, with per-model subfolders;
- a results folder, which stores the evaluation results in JSON, with per-model subfolders.
In addition, whenever -r is specified without a corresponding timestamp, the latest folder in sorted order is selected as the output directory.
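For convenience, the following sketch (assuming the directory layout above; the dataset filename is hypothetical) shows how to locate the newest timestamped run and read one of its prediction files:
```python
import json
import os
import os.path as osp

work_dir = 'outputs/default'
# pick the newest timestamped run, mirroring what `-r` does without a timestamp
latest = sorted(os.listdir(work_dir))[-1]
# 'MODEL1' matches the tree above; 'dataset.json' is a hypothetical filename
pred_file = osp.join(work_dir, latest, 'predictions', 'MODEL1', 'dataset.json')
with open(pred_file) as f:
    predictions = json.load(f)
print(len(predictions))
```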
from .afqmcd import * # noqa: F401, F403
from .agieval import * # noqa: F401, F403
from .arc import * # noqa: F401, F403
from .ax import * # noqa: F401, F403
from .bbh import * # noqa: F401, F403
from .boolq import * # noqa: F401, F403
from .bustum import * # noqa: F401, F403
from .c3 import * # noqa: F401, F403
from .cb import * # noqa: F401, F403
from .ceval import * # noqa: F401, F403
from .chid import * # noqa: F401, F403
from .civilcomments import * # noqa: F401, F403
from .cluewsc import * # noqa: F401, F403
from .cmnli import * # noqa: F401, F403
from .cmrc import * # noqa: F401, F403
from .commonsenseqa import * # noqa: F401, F403
from .copa import * # noqa: F401, F403
from .crowspairs import * # noqa: F401, F403
from .csl import * # noqa: F401, F403
from .drcd import * # noqa: F401, F403
from .drop import * # noqa: F401, F403
from .eprstmt import * # noqa: F401, F403
from .flores import * # noqa: F401, F403
from .GaokaoBench import * # noqa: F401, F403
from .govrepcrs import * # noqa: F401, F403
from .gsm8k import * # noqa: F401, F403
from .hellaswag import * # noqa: F401, F403
from .huggingface import * # noqa: F401, F403
from .humaneval import * # noqa: F401, F403
from .iwslt2017 import * # noqa: F401, F403
from .jigsawmultilingual import * # noqa: F401, F403
from .lambada import * # noqa: F401, F403
from .lcsts import * # noqa: F401, F403
from .math import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403
from .mmlu import * # noqa: F401, F403
from .multirc import * # noqa: F401, F403
from .narrativeqa import * # noqa: F401, F403
from .natural_question import * # noqa: F401, F403
from .obqa import * # noqa: F401, F403
from .piqa import * # noqa: F401, F403
from .qasper import * # noqa: F401, F403
from .qaspercut import * # noqa: F401, F403
from .race import * # noqa: F401, F403
from .realtoxicprompts import * # noqa: F401, F403
from .record import * # noqa: F401, F403
from .safety import * # noqa: F401, F403
from .siqa import * # noqa: F401, F403
from .storycloze import * # noqa: F401, F403
from .strategyqa import * # noqa: F401, F403
from .summedits import * # noqa: F401, F403
from .summscreen import * # noqa: F401, F403
from .TheoremQA import * # noqa: F401, F403
from .tnews import * # noqa: F401, F403
from .triviaqa import * # noqa: F401, F403
from .triviaqarc import * # noqa: F401, F403
from .truthfulqa import * # noqa: F401, F403
from .wic import * # noqa: F401, F403
from .winograd import * # noqa: F401, F403
from .winogrande import * # noqa: F401, F403
from .wsc import * # noqa: F401, F403
from .xcopa import * # noqa: F401, F403
from .xlsum import * # noqa: F401, F403
from .xsum import * # noqa: F401, F403
import json
import os.path as osp
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from .math_equivalence import is_equiv
from .post_process import parse_math_answer
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class AGIEvalDataset(BaseDataset):
@staticmethod
def load(path: str, name: str, setting_name: str):
from .dataset_loader import load_dataset, load_dataset_as_result_schema
assert setting_name == 'zero-shot', 'only support zero-shot setting'
dataset_wo_label = load_dataset(name, setting_name, path)
dataset_with_label = load_dataset_as_result_schema(name, path)
dataset = []
for d1, d2 in zip(dataset_wo_label, dataset_with_label):
dataset.append({
'id': d2.index,
'problem_input': d1['context'],
'label': d2.label,
})
dataset = Dataset.from_list(dataset)
return dataset
@LOAD_DATASET.register_module()
class AGIEvalDataset_v2(BaseDataset):
@staticmethod
def load(path: str, name: str, setting_name: str):
assert setting_name == 'zero-shot', 'only support zero-shot setting'
filename = osp.join(path, name + '.jsonl')
with open(filename) as f:
_data = [json.loads(line.strip()) for line in f]
data = []
for _d in _data:
passage = _d['passage'] if _d['passage'] else ''
question = passage + _d['question']
options = '\n'.join(_d['options']) if _d['options'] else ''
label = _d['label'] if _d['label'] else _d['answer']
d = {'question': question, 'options': options, 'label': label}
data.append(d)
dataset = Dataset.from_list(data)
return dataset
@ICL_EVALUATORS.register_module()
class AGIEvalEvaluator(BaseEvaluator):
def score(self, predictions, references):
predictions = [parse_math_answer('', pred) for pred in predictions]
cnt = 0
for pred, ref in zip(predictions, references):
if is_equiv(pred, ref):
cnt += 1
score = cnt / len(predictions) * 100
return {'score': score}
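# Illustrative usage (not part of the original file): score() parses each
# prediction with parse_math_answer and reports accuracy as a percentage.
#   evaluator = AGIEvalEvaluator()
#   evaluator.score(predictions=['5', '7'], references=['5', '4'])
#   # -> {'score': 50.0}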
# flake8: noqa
import pandas as pd
class TaskSchema(object):
def __init__(self,
passage=None,
question=None,
options=None,
label=None,
answer=None,
other=None):
self.passage = passage
self.question = question
self.options = options
self.label = label
self.answer = answer
self.other = other
def to_dict(self):
return {
'passage': self.passage,
'question': self.question,
'options': self.options,
'label': self.label,
'answer': self.answer,
'other': self.other
}
# define README.json
class AgiInstance(object):
def __init__(self, task_description, data_source, task_schema, output,
evaluation_metric, task_example):
self.task_description = task_description
self.data_source = data_source
self.task_schema = task_schema
self.output = output
self.evaluation_metric = evaluation_metric
self.task_example = task_example
def to_dict(self):
return {
'task description': self.task_description,
'data source': self.data_source,
'task schema': self.task_schema.to_dict(),
'output': self.output,
'evaluation metric': self.evaluation_metric,
'task example': self.task_example
}
class ChatGPTSchema(object):
def __init__(self, context=None, metadata=''):
self.context = context
self.metadata = metadata
def to_dict(self):
return {'context': self.context, 'metadata': self.metadata}
class ResultsForHumanSchema(object):
def __init__(self,
index,
problem_input,
label,
model_input='',
model_output='',
parse_result='',
first_stage_output='',
second_stage_input='',
is_correct=False):
self.index = index
self.problem_input = problem_input
self.model_input = model_input
self.model_output = model_output
self.parse_result = parse_result
self.label = label
self.first_stage_output = first_stage_output
self.second_stage_input = second_stage_input
self.is_correct = is_correct
def to_dict(self):
return {
'index': self.index,
'problem_input': self.problem_input,
'model_input': self.model_input,
'model_output': self.model_output,
'parse_result': self.parse_result,
'label': self.label,
'is_correct': self.is_correct,
'first_stage_output': self.first_stage_output,
'second_stage_input': self.second_stage_input,
}
@staticmethod
def to_tsv(result_list, path):
result_json = [item.to_dict() for item in result_list]
table = pd.json_normalize(result_json)
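# NOTE: despite the method name (to_tsv), DataFrame.to_excel writes an Excel workbook here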
table.to_excel(path, index=False)
# flake8: noqa
import ast
import json
import os
import pandas as pd
import tiktoken
from tqdm import tqdm
from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl
# define the datasets
english_qa_datasets = [
'lsat-ar', 'lsat-lr', 'lsat-rc', 'logiqa-en', 'sat-math', 'sat-en',
'aqua-rat', 'sat-en-without-passage', 'gaokao-english'
]
chinese_qa_datasets = [
'logiqa-zh', 'jec-qa-kd', 'jec-qa-ca', 'gaokao-chinese',
'gaokao-geography', 'gaokao-history', 'gaokao-biology', 'gaokao-chemistry',
'gaokao-physics', 'gaokao-mathqa'
]
english_cloze_datasets = ['math']
chinese_cloze_datasets = ['gaokao-mathcloze']
multi_choice_datasets = ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']
math_output_datasets = ['gaokao-mathcloze', 'math']
def convert_zero_shot(line, dataset_name):
try:
passage = line['passage'] if line['passage'] is not None else ''
if dataset_name in english_qa_datasets:
option_string = 'ABCDEFG'
count = len(line['options'])
if count == 1:
count = 5
return passage + 'Q: ' + line['question'] + ' ' \
+ 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
'A: Among A through {}, the answer is'.format(option_string[count - 1])
elif dataset_name in chinese_qa_datasets:
option_string = 'ABCDEFG'
count = len(line['options'])
if count == 1:
count = 4
return passage + '问题:' + line['question'] + ' ' \
+ '选项:' + ' '.join(line['options']) + '\n' + \
'答案:从A到{}, 我们应选择'.format(option_string[count - 1])
elif dataset_name in english_cloze_datasets:
return passage + 'Q: ' + line['question'] + '\n' \
'A: The answer is'
elif dataset_name in chinese_cloze_datasets:
return passage + '问题:' + line['question'] + '\n' \
'答案:'
except NameError:
print('Dataset not defined.')
prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'
def convert_zero_shot_CoT_stage1(line, dataset_name):
try:
passage = line['passage'] if line['passage'] is not None else ''
if dataset_name in english_qa_datasets:
return passage + 'Q: ' + line['question'] + ' ' \
+ 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
"Let's think step by step."
elif dataset_name in chinese_qa_datasets:
option_string = 'ABCDEFG'
count = len(line['options'])
if count == 1:
count = 4
return passage + '问题:' + line['question'] + ' ' \
+ '选项:' + ' '.join(line['options']) + '\n' + \
'从A到{}, 我们应选择什么?让我们逐步思考:'.format(option_string[count - 1])
elif dataset_name in english_cloze_datasets:
return passage + 'Q: ' + line['question'] + '\n' \
"A: Let's think step by step."
elif dataset_name in chinese_cloze_datasets:
return passage + '问题:' + line['question'] + '\n' \
'答案:让我们逐步思考:'
except NameError:
print('Dataset not defined.')
# process few-shot raw_prompts
def combine_prompt(prompt_path,
dataset_name,
load_explanation=True,
chat_mode=False):
skip_passage = False
if dataset_name == 'sat-en-without-passage':
skip_passage = True
dataset_name = 'sat-en'
demostrations = []
# read the prompts by context and explanation
context_row = [0, 1, 3, 5, 7, 9]
explanation_row = [0, 2, 4, 6, 8, 10]
raw_prompts_context = pd.read_csv(prompt_path,
header=0,
skiprows=lambda x: x not in context_row,
keep_default_na=False)
raw_prompts_explanation = pd.read_csv(
prompt_path,
header=0,
skiprows=lambda x: x not in explanation_row,
keep_default_na=False).replace(r'\n\n', '\n', regex=True)
contexts = []
for line in list(raw_prompts_context[dataset_name]):
if line:
# print(line)
contexts.append(ast.literal_eval(line))
explanations = [
exp for exp in raw_prompts_explanation[dataset_name] if exp
]
for idx, (con, exp) in enumerate(zip(contexts, explanations)):
passage = con['passage'] if con[
'passage'] is not None and not skip_passage else ''
question = con['question']
options = con['options'] if con['options'] is not None else ''
label = con['label'] if con['label'] is not None else ''
answer = con[
'answer'] if 'answer' in con and con['answer'] is not None else ''
if dataset_name in english_qa_datasets:
question_input = 'Problem {}. '.format(idx + 1) + passage + ' ' + question + '\n' \
+ 'Choose from the following options: ' + ' '.join(options) + '\n'
question_output = (('Explanation for Problem {}: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ 'The answer is therefore {}'.format(label)
elif dataset_name in chinese_qa_datasets:
question_input = '问题 {}. '.format(idx + 1) + passage + ' ' + question + '\n' \
+ '从以下选项中选择: ' + ' '.join(options) + '\n'
question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ '答案是 {}'.format(label)
elif dataset_name in english_cloze_datasets:
question_input = 'Problem {}. '.format(idx + 1) + question + '\n'
question_output = (('Explanation for Problem {}: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ 'The answer is therefore {}'.format(answer)
elif dataset_name in chinese_cloze_datasets:
question_input = '问题 {}. '.format(idx + 1) + question + '\n'
question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ '答案是 {}'.format(answer)
else:
raise ValueError(
f'During loading few-shot examples, found unknown dataset: {dataset_name}'
)
if chat_mode:
demostrations.append((question_input, question_output))
else:
demostrations.append(question_input + question_output + '\n')
return demostrations
enc = None
def _lazy_load_enc():
global enc
if enc is None:
enc = tiktoken.encoding_for_model('gpt-4')
# cut prompt if reach max token length
def concat_prompt(demos,
dataset_name,
max_tokens,
end_of_example='\n',
verbose=False):
_lazy_load_enc()
demostration_en = 'Here are the answers for the problems in the exam.\n'
demostration_zh = '以下是考试中各个问题的答案。\n'
for i in range(len(demos)):
# print(len(enc.encode(demostration_en)), len(enc.encode(demostration_zh)))
if dataset_name in english_qa_datasets:
demostration_en = demostration_en + demos[i] + end_of_example
elif dataset_name in chinese_qa_datasets:
demostration_zh = demostration_zh + demos[i] + end_of_example
elif dataset_name in english_cloze_datasets:
demostration_en = demostration_en + demos[i] + end_of_example
elif dataset_name in chinese_cloze_datasets:
demostration_zh = demostration_zh + demos[i] + end_of_example
# break if reach max token limit
if len(enc.encode(demostration_en)) < max_tokens and len(
enc.encode(demostration_zh)) < max_tokens:
output = demostration_en if len(demostration_en) > len(
demostration_zh) else demostration_zh
prompt_num = i + 1
else:
break
if verbose:
print('max_tokens set as ', max_tokens, 'actual_tokens is',
len(enc.encode(output)), 'num_shot is', prompt_num)
return output, prompt_num
def concat_prompt_chat_mode(demos,
dataset_name,
max_tokens,
end_of_example='\n',
verbose=False):
_lazy_load_enc()
answers = []
sentences = ''
for i in range(len(demos)):
answers += [
{
'role': 'user',
'content': demos[i][0]
},
{
'role': 'assistant',
'content': demos[i][1]
},
]
sentences += json.dumps(answers[-1])
# break if reach max token limit
if len(enc.encode(sentences)) > max_tokens:
answers.pop()
answers.pop()
break
if verbose:
print('max_tokens set as ', max_tokens, 'actual_tokens is',
len(enc.encode(sentences)), 'num_shot is',
len(answers) // 2)
return answers, len(answers) // 2
def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False):
passage = line['passage'] if line['passage'] is not None else ''
question = line['question']
options = line['options'] if line['options'] is not None else ''
if dataset_name in english_qa_datasets:
question_input = 'Problem {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \
+ 'Choose from the following options: ' + ' '.join(options) + '\n'
# + "Explanation for Problem {}: ".format(n_shot + 1)
if dataset_name in chinese_qa_datasets:
question_input = '问题 {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \
+ '从以下选项中选择: ' + ' '.join(options) + '\n'
# + "问题 {}的解析: ".format(n_shot + 1)
if dataset_name in english_cloze_datasets:
question_input = 'Problem {}. '.format(n_shot + 1) + question + '\n'
# + "Explanation for Problem {}: ".format(n_shot + 1)
if dataset_name in chinese_cloze_datasets:
question_input = '问题 {}. '.format(n_shot + 1) + question + '\n'
# + "问题 {}的解析: ".format(n_shot + 1)
if chat_mode:
return demo + [
{
'role': 'user',
'content': question_input
},
]
else:
return demo + question_input
def load_dataset(dataset_name,
setting_name,
parent_path,
prompt_path=None,
max_tokens=None,
end_of_example='\n',
chat_mode=False,
verbose=False):
test_path = os.path.join(parent_path, dataset_name + '.jsonl')
loaded_jsonl = read_jsonl(test_path)
processed = []
if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
# process demo once if it is few-shot-CoT
processed_demos = combine_prompt(
prompt_path,
dataset_name,
load_explanation=setting_name == 'few-shot-CoT',
chat_mode=chat_mode)
if chat_mode:
chosen_prompt, n_shot = concat_prompt_chat_mode(processed_demos,
dataset_name,
max_tokens,
end_of_example,
verbose=verbose)
else:
chosen_prompt, n_shot = concat_prompt(processed_demos,
dataset_name,
max_tokens,
end_of_example,
verbose=verbose)
if verbose:
loaded_jsonl = tqdm(loaded_jsonl)
for meta_idx, line in enumerate(loaded_jsonl):
if setting_name == 'zero-shot':
ctxt = convert_zero_shot(line, dataset_name)
elif setting_name == 'zero-shot-CoT':
ctxt = convert_zero_shot_CoT_stage1(line, dataset_name)
elif setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
ctxt = convert_few_shot(line, dataset_name, chosen_prompt, n_shot,
chat_mode)
try:
new_instance = ChatGPTSchema(context=ctxt, metadata=meta_idx)
processed.append(new_instance.to_dict())
except NameError:
print('Dataset not defined.')
return processed
def generate_second_stage_input(dataset_name,
input_list,
output_list,
with_format_prompt=False):
try:
english_format_prompt = 'Based on the previous results, your task is to extract the final answer and provide the output enclosed in brackets【】, such as 【0】 or 【A】.'
chinese_format_prompt = '根据以上内容,你的任务是把最终的答案提取出来并填在【】中,例如【0】或者【A】。'
if dataset_name in english_qa_datasets:
prompt_suffix = 'Therefore, among A through E, the answer is'
if with_format_prompt:
prompt_suffix = english_format_prompt + prompt_suffix
elif dataset_name in chinese_qa_datasets:
prompt_suffix = '因此,从A到D, 我们应选择'
if with_format_prompt:
prompt_suffix = chinese_format_prompt + prompt_suffix
elif dataset_name in english_cloze_datasets:
prompt_suffix = 'Therefore, the answer is'
if with_format_prompt:
prompt_suffix = english_format_prompt + prompt_suffix
elif dataset_name in chinese_cloze_datasets:
prompt_suffix = '因此,答案是'
if with_format_prompt:
prompt_suffix = chinese_format_prompt + prompt_suffix
except NameError:
print('Dataset not defined.')
processed = []
for i in range(len(input_list)):
ctxt = '{0}\n{1}\n{2}'.format(input_list[i]['context'],
extract_answer(output_list[i]),
prompt_suffix)
new_instance = ChatGPTSchema(context=ctxt,
metadata=input_list[i]['metadata'])
processed.append(new_instance.to_dict())
return processed
def load_dataset_as_result_schema(dataset_name, parent_path):
test_path = os.path.join(parent_path, dataset_name + '.jsonl')
loaded_jsonl = read_jsonl(test_path)
processed = []
for i, line in enumerate(loaded_jsonl):
problem_input = convert_zero_shot(line, dataset_name)
processed.append(
ResultsForHumanSchema(
index=i,
problem_input=problem_input,
label=line['label'] if line['label'] else line['answer'],
))
return processed
if __name__ == '__main__':
# set variables
parent_dir = '../../data/V1_1/'
raw_prompt_path = '../data/few_shot_prompts.csv'
# set dataset name to process
setting_name = 'few-shot-CoT' # setting_name can be chosen from ["zero-shot", "zero-shot-CoT", "few-shot", "few-shot-CoT"]
data_name = 'jec-qa-kd'
save_dir = '../../experiment_input/{}/'.format(setting_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
processed_data = load_dataset(data_name,
setting_name,
parent_dir,
prompt_path=raw_prompt_path,
max_tokens=2048)
save_jsonl(processed_data,
os.path.join(save_dir, '{}.jsonl'.format(data_name)))
# flake8: noqa
from . import dataset_loader, utils
from .math_equivalence import is_equiv
def convert_to_set(item):
if isinstance(item, list):
return set(item)
if isinstance(item, str):
return {item}
if item is None:
return set()
raise ValueError("Input can't parse:", item)
def evaluate_single_sample(dataset_name, prediction, label):
if dataset_name in dataset_loader.multi_choice_datasets:
p = convert_to_set(prediction)
l = convert_to_set(label)
return p == l
elif dataset_name in dataset_loader.math_output_datasets:
return is_equiv(prediction, label)
else:
return prediction == label
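# Illustrative examples: multi-choice answers compare as sets (order-insensitive),
# while math outputs are compared via is_equiv.
#   evaluate_single_sample('jec-qa-kd', ['A', 'B'], ['B', 'A'])  # -> True
#   evaluate_single_sample('math', '0.5', '\\frac{1}{2}')        # -> True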
# def evaluate(dataset_name, prediction_list, label_list):
# correct = 0
# if dataset_name in multi_choice_datasets:
# for prediction, label in zip(prediction_list, label_list):
# p = convert_to_set(prediction)
# l = convert_to_set(label)
# if p == l:
# correct += 1
# elif dataset_name in math_output_datasets:
# for prediction, label in zip(prediction_list, label_list):
# if is_equiv(prediction, label):
# correct += 1
# else:
# for prediction, label in zip(prediction_list, label_list):
# if prediction == label:
# correct += 1
# return "{0:.2%}".format(correct / len(label_list))
# flake8: noqa
# code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
def _fix_fracs(string):
substrs = string.split('\\frac')
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if substr[0] == '{':
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != '{':
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}{' + b + '}' + post_substr
else:
new_str += '{' + a + '}{' + b + '}'
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}' + b + post_substr
else:
new_str += '{' + a + '}' + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split('/')) != 2:
return string
a = string.split('/')[0]
b = string.split('/')[1]
try:
a = int(a)
b = int(b)
assert string == '{}/{}'.format(a, b)
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string
except:
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if '\\text{ ' in string:
splits = string.split('\\text{ ')
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if '\\sqrt' not in string:
return string
splits = string.split('\\sqrt')
new_string = splits[0]
for split in splits[1:]:
if split[0] != '{':
a = split[0]
new_substr = '\\sqrt{' + a + '}' + split[1:]
else:
new_substr = '\\sqrt' + split
new_string += new_substr
return new_string
def _strip_string(string):
# linebreaks
string = string.replace('\n', '')
# print(string)
# remove inverse spaces
string = string.replace('\\!', '')
# print(string)
# replace \\ with \
string = string.replace('\\\\', '\\')
# print(string)
# replace tfrac and dfrac with frac
string = string.replace('tfrac', 'frac')
string = string.replace('dfrac', 'frac')
# print(string)
# remove \left and \right
string = string.replace('\\left', '')
string = string.replace('\\right', '')
# print(string)
# Remove circ (degrees)
string = string.replace('^{\\circ}', '')
string = string.replace('^\\circ', '')
# remove dollar signs
string = string.replace('\\$', '')
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace('\\%', '')
string = string.replace('\%', '')
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(' .', ' 0.')
string = string.replace('{.', '{0.')
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == '.':
string = '0' + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split('=')) == 2:
if len(string.split('=')[0]) <= 2:
string = string.split('=')[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(' ', '')
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == '0.5':
string = '\\frac{1}{2}'
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print('WARNING: Both None')
return True
if str1 is None or str2 is None:
return False
try:
ss1 = _strip_string(str1)
ss2 = _strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except:
return str1 == str2
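# Illustrative examples of the normalization done by _strip_string:
#   is_equiv('\\frac{1}{2}', '0.5')      # -> True ('0.5' is rewritten to \frac{1}{2})
#   is_equiv('\\left(3\\right)', '(3)')  # -> True (\left and \right are stripped)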
# flake8: noqa
import json
import re
from . import dataset_loader
def extract_last_line(string):
lines = string.split('\n')
for item in lines[::-1]:
if item.strip() != '':
string = item
break
return string
def remove_few_shot_prefix(string: str):
prefix_list = ['The answer is therefore', '答案是']
for prefix in prefix_list:
if string.startswith(prefix):
string = string[len(prefix):].strip()
elif prefix in string:
index = string.rfind(prefix)
if index >= 0:
string = string[index + len(prefix):].strip()
return string
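# Illustrative example:
#   remove_few_shot_prefix('The answer is therefore B')  # -> 'B'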
def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
if language == 'en':
pattern = 'answer is .*?([A-G])'
match = re.search(pattern, string)
elif language == 'zh':
pattern = '答案是.*?([A-G])'
match = re.search(pattern, string)
else:
raise ValueError('Unknown language {0}'.format(language))
if match:
return match.group(1)
else:
return None
def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
if dataset_name in dataset_loader.chinese_cloze_datasets:
return string.startswith('答案是')
elif dataset_name in dataset_loader.english_cloze_datasets:
return string.startswith('The answer is therefore')
elif dataset_name in dataset_loader.chinese_qa_datasets:
pattern = '答案是.*?([A-G])'
match = re.search(pattern, string)
return match is not None
elif dataset_name in dataset_loader.english_qa_datasets:
pattern = 'answer is .*?([A-G])'
match = re.search(pattern, string)
return match is not None
return False
def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
answer = try_parse_few_shot_qa_single_answer(string, setting_name,
language)
if answer is None:
return find_first_capital_letter(string)
else:
return answer
def find_first_capital_letter(answer):
letter_set = {'A', 'B', 'C', 'D', 'E', 'F'}
for c in answer:
if c in letter_set:
return c
# print("Can't find capital letter in:", answer)
return ''
def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
if prefix not in answer and suffix not in answer:
# print("doesn't found special tokens in:", answer)
return ''
s = answer.index(prefix) + len(prefix)
t = answer.index(suffix)
ret = answer[s:t]
return ret
def parse_math_answer(setting_name, raw_string):
if setting_name == 'few-shot-CoT':
raw_string = extract_last_line(raw_string)
if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
raw_string = remove_few_shot_prefix(raw_string)
return raw_string
def remove_boxed(s):
left = '\\boxed{'
try:
assert s[:len(left)] == left
assert s[-1] == '}'
answer = s[len(left):-1]
if '=' in answer:
answer = answer.split('=')[-1].lstrip(' ')
return answer
except:
return None
def last_boxed_only_string(string):
idx = string.rfind('\\boxed')
if idx < 0:
idx = string.rfind('\\fbox')
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == '{':
num_left_braces_open += 1
if string[i] == '}':
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def get_answer_with_dollar_sign(s):
first_pattern = '\$(.*)\$'
last_match = None
matches = re.findall(first_pattern, s)
if matches:
last_match = matches[-1]
if '=' in last_match:
last_match = last_match.split('=')[-1].lstrip(' ')
return last_match
def get_answer_without_dollar_sign(s):
last_match = None
if '=' in s:
last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
if '\\n' in last_match:
last_match = last_match.split('\\n')[0]
else:
pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
matches = re.findall(pattern, s)
if matches:
last_match = matches[-1]
return last_match
raw_string = remove_few_shot_prefix(raw_string)
if '\\boxed' in raw_string:
answer = remove_boxed(last_boxed_only_string(raw_string))
else:
answer = get_answer_with_dollar_sign(raw_string)
if not answer:
answer = get_answer_without_dollar_sign(raw_string)
return answer
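# Illustrative examples of the extraction order in parse_math_answer:
# \boxed{...} first, then the last $...$ group, then the last bare number.
#   parse_math_answer('', 'so we get \\boxed{7}')        # -> '7'
#   parse_math_answer('', 'the area is $\\frac{x}{2}$')  # -> '\\frac{x}{2}'
#   parse_math_answer('', 'there are 12 apples')         # -> '12'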
def parse_qa_multiple_answer(string, setting_name):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
pattern = '\(*([A-Z])\)*'
match = re.findall(pattern, string)
if match:
return match
return []
def post_process(dataset_name, setting_name, prediction):
if dataset_name in dataset_loader.english_cloze_datasets or dataset_name in dataset_loader.chinese_cloze_datasets:
return parse_math_answer(setting_name, prediction)
if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
return parse_qa_multiple_answer(prediction, setting_name)
# all other datasets are QA problems with single answer
if 'zero-shot' in setting_name:
answer = find_first_capital_letter(prediction)
return answer
# all other datasets are QA problems with single answer and setting_name are few-shot
language = 'en' if dataset_name in dataset_loader.english_qa_datasets else 'zh'
if dataset_name in dataset_loader.english_qa_datasets or dataset_name in dataset_loader.chinese_qa_datasets:
return parse_few_shot_qa_single_answer(prediction, setting_name,
language)
else:
raise ValueError(f'Unsupported dataset name {dataset_name}')
# flake8: noqa
import json
def read_jsonl(path):
with open(path, encoding='utf8') as fh:
results = []
for line in fh:
if line is None:
continue
try:
results.append(json.loads(line) if line != 'null' else line)
except Exception as e:
print(e)
print(path)
print(line)
raise e
return results
def save_jsonl(lines, directory):
with open(directory, 'w', encoding='utf8') as f:
for line in lines:
f.write(json.dumps(line, ensure_ascii=False) + '\n')
def extract_answer(js):
try:
if js is None or js == 'null':
return ''
answer = ''
if isinstance(js, str):
answer = js
elif 'text' in js['choices'][0]:
answer = js['choices'][0]['text']
else:
answer = js['choices'][0]['message']['content']
# answer = js['']
return answer
except Exception as e:
# print(e)
# print(js)
return ''
import json
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class C3Dataset(BaseDataset):
@staticmethod
def load(path: str):
with open(path) as f:
data = json.load(f)
rows = []
for _, row in enumerate(data):
content = row[0]
content_str = ' '.join(
[''.join(paragraph) for paragraph in content])
for question in row[1]:
label = question['choice'].index(question['answer'])
length = len(question['choice'])
if length < 4:
fill_value = question['choice'][0] # use the first choice as the filler value
fill_count = 4 - length # number of fillers needed
question['choice'] += [fill_value] * fill_count # pad to 4 choices
rows.append({
'content': content_str,
'question': question['question'],
'choices': question['choice'],
'choice0': question['choice'][0],
'choice1': question['choice'][1],
'choice2': question['choice'][2],
'choice3': question['choice'][3],
'label': label
})
dataset = Dataset.from_dict({
'content': [row['content'] for row in rows],
'question': [row['question'] for row in rows],
'choice0': [row['choice0'] for row in rows],
'choice1': [row['choice1'] for row in rows],
'choice2': [row['choice2'] for row in rows],
'choice3': [row['choice3'] for row in rows],
'choices': [row['choices'] for row in rows],
'label': [row['label'] for row in rows]
})
return dataset
@LOAD_DATASET.register_module()
class C3Dataset_V2(BaseDataset):
@staticmethod
def load(path: str):
with open(path) as f:
raw = json.load(f)
data = []
for line in raw:
content = ''.join([''.join(paragraph) for paragraph in line[0]])
for question in line[1]:
label = question['choice'].index(question['answer'])
label = 'ABCD'[label]
while len(question['choice']) < 4:
question['choice'].append('[NULL]')
data.append({
'content': content,
'question': question['question'],
'choice0': question['choice'][0],
'choice1': question['choice'][1],
'choice2': question['choice'][2],
'choice3': question['choice'][3],
'label': label
})
return Dataset.from_list(data)
import json
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class CluewscDataset(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
def preprocess(example):
text_list = list(example['text'])
# span1 may have 1 or more than 1 words
# span2 is the pronoun and has only 1 word
text_list[example['target']
['span2_index']] = example['target']['span1_text']
example['new_text'] = ''.join(text_list)
if example['label'] == 'true':
example['answer'] = 1
else:
example['answer'] = 0
example['span1'] = example['target']['span1_text']
example['span2'] = example['target']['span2_text']
del example['target']
return example
dataset = dataset.map(preprocess)
return dataset
@LOAD_DATASET.register_module()
class CluewscDataset_V2(BaseDataset):
@staticmethod
def load(path):
data = []
with open(path, 'r') as f:
for line in f:
line = json.loads(line)
item = {
'span1': line['target']['span1_text'],
'span2': line['target']['span2_text'],
'text': line['text'],
'label': {
'true': 'A',
'false': 'B'
}[line['label']],
}
data.append(item)
return Dataset.from_list(data)
import json
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class COPADataset_V2(BaseDataset):
@staticmethod
def load(path):
dataset = []
with open(path, 'r') as f:
for line in f:
line = json.loads(line)
line['label'] = 'AB'[line['label']]
dataset.append(line)
return Dataset.from_list(dataset)
import json
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class DRCDDataset(BaseDataset):
@staticmethod
def load(path: str):
with open(path) as f:
data = json.load(f)
# Convert the raw data into the required format
rows = []
for index, paragraphs in enumerate(data['data']):
for paragraph in paragraphs['paragraphs']:
context = paragraph['context']
for question in paragraph['qas']:
answers = question['answers']
unique_answers = list(set([a['text'] for a in answers]))
rows.append({
'context': context,
'question': question['question'],
'answers': unique_answers
})
# Create the Dataset
dataset = Dataset.from_dict({
'context': [row['context'] for row in rows],
'question': [row['question'] for row in rows],
'answers': [row['answers'] for row in rows]
})
return dataset
from datasets import load_dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class hellaswagDataset(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
def preprocess(example):
for i in range(4):
example[chr(ord('A') + i)] = example['endings'][i]
return example
dataset = dataset.map(preprocess).remove_columns(['endings'])
return dataset
@LOAD_DATASET.register_module()
class hellaswagDataset_V2(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
def preprocess(example):
for i in range(4):
example[chr(ord('A') + i)] = example['endings'][i]
if example['label']:
example['label'] = 'ABCD'[int(example['label'])]
else:
example['label'] = 'NULL'
return example
dataset = dataset.map(preprocess).remove_columns(['endings'])
return dataset
import os.path as osp
import tempfile
from typing import List
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, TEXT_POSTPROCESSORS
@ICL_EVALUATORS.register_module()
class HumanEvaluator(BaseEvaluator):
"""Evaluator for human eval."""
def __init__(self, k: List[int] = [1, 10, 100]) -> None:
try:
from human_eval.data import HUMAN_EVAL, write_jsonl
from human_eval.evaluation import evaluate_functional_correctness
self.write_jsonl = write_jsonl
self.HUMAN_EVAL = HUMAN_EVAL
self.eval = evaluate_functional_correctness
except ImportError:
raise ImportError('Please install human_eval following'
'https://github.com/openai/human-eval/tree/'
'master#installation first.')
self.k = k
super().__init__()
def score(self, predictions, references):
predictions = [{
'task_id': f'HumanEval/{i}',
'completion': predictions[i]
} for i in range(len(predictions))]
with tempfile.TemporaryDirectory() as tmp_dir:
out_dir = osp.join(tmp_dir, 'human_eval.json')
self.write_jsonl(out_dir, predictions)
score = self.eval(
out_dir,
self.k,
n_workers=4,
timeout=3.0,
problem_file=self.HUMAN_EVAL)
return {f'humaneval_{k}': score[k] * 100 for k in score}
@TEXT_POSTPROCESSORS.register_module('humaneval')
def humaneval_postprocess(text: str) -> str:
text = text.split('\n\n')[0]
if '```' in text:
text = text.split('```')[1]
if text.startswith('def'):
text = '\n'.join(text.split('\n')[1:])
if not text.startswith('    '):
if text.startswith(' '):
text = '    ' + text.lstrip()
else:
text = '\n'.join(['    ' + line for line in text.split('\n')])
return text
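# Illustrative example: only the first paragraph is kept, a leading `def`
# line is dropped, and an already-indented body is left untouched.
#   humaneval_postprocess('def add(a, b):\n    return a + b\n\nextra')
#   # -> '    return a + b'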
import csv
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class JigsawMultilingualDataset(BaseDataset):
@staticmethod
def load(path, label, lang):
assert lang in ['es', 'fr', 'it', 'pt', 'ru', 'tr']
dataset = DatasetDict()
data_list = list()
idx = 0
with open(path) as file, open(label) as label:
text_reader = csv.reader(file)
label_reader = csv.reader(label)
for text, target in zip(text_reader, label_reader):
if text[2] == lang:
assert text[0] == target[0]
data_list.append({
'idx': idx,
'text': text[1],
'label': int(target[1]),
'choices': ['no', 'yes']
})
idx += 1
dataset['test'] = Dataset.from_list(data_list)
return dataset
import re
import string
from datasets import DatasetDict, load_dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils.text_postprocessors import general_postprocess
from .base import BaseDataset
@LOAD_DATASET.register_module()
class lambadaDataset(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs, split='test')
def preprocess(example):
prompt, target = example['text'].strip().rsplit(' ', 1)
example['prompt'] = prompt
example['label'] = target
return example
dataset = dataset.map(preprocess)
return DatasetDict({'test': dataset})
@ICL_EVALUATORS.register_module()
class LambadaEvaluator(BaseEvaluator):
def __init__(self) -> None:
super().__init__()
def score(self, predictions, references):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
score = 0.0
for pred, refer in zip(predictions, references):
pred = pred.strip().split(' ')[0]
pred = re.split(f'[{string.punctuation}]', pred)[0]
score += general_postprocess(pred) == general_postprocess(refer)
score = 100.0 * score / len(predictions)
return dict(accuracy=score)
import json
from datasets import Dataset, DatasetDict
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
from .base import BaseDataset
@LOAD_DATASET.register_module()
class MATHDataset(BaseDataset):
@staticmethod
def load(path: str):
def remove_boxed(s):
left = '\\boxed{'
try:
assert s[:len(left)] == left
assert s[-1] == '}'
return s[len(left):-1]
except Exception:
return None
def last_boxed_only_string(string):
idx = string.rfind('\\boxed')
if idx < 0:
idx = string.rfind('\\fbox')
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == '{':
num_left_braces_open += 1
if string[i] == '}':
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
dataset = DatasetDict()
data = json.load(open(path))
raw_data = []
for i in data.keys():
raw_data.append({
'problem':
data[i]['problem'],
'solution':
remove_boxed(last_boxed_only_string(data[i]['solution']))
})
dataset['test'] = Dataset.from_list(raw_data)
dataset['train'] = Dataset.from_list(raw_data)
return dataset
@TEXT_POSTPROCESSORS.register_module('math')
def math_postprocess(text: str) -> str:
SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''),
(r'\ ', ''), (' ', ''), ('mbox', 'text'),
(',\\text{and}', ','), ('\\text{and}', ','),
('\\text{m}', '\\text{}'), ('\le', '<')]
REMOVED_EXPRESSIONS = [
'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft',
'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet', 'minutes',
'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 'meters', 'meals',
'edges', 'students', 'childrentickets', 'multiples', '\\text{s}',
'\\text{.}', '\\text{\ns}', '\\text{}^2', '\\text{}^3', '\\text{\n}',
'\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!',
'{,}', '"', '\\dots', '\n', '\r', '\f'
]
import re
def normalize_final_answer(final_answer: str) -> str:
"""Normalize a final answer to a quantitative reasoning question."""
# final_answer = final_answer.split('=')[-1]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, '')
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer)
final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)
assert '\n' not in final_answer and '\r' not in final_answer and '\f' not in final_answer
if len(re.findall(r'finalansweris(.*)', final_answer)) > 0:
final_answer = re.findall(r'finalansweris(.*)', final_answer)[-1]
if len(re.findall(r'oxed\{(.*?)\}', final_answer)) > 0:
final_answer = re.findall(r'oxed\{(.*?)\}', final_answer)[-1]
if len(re.findall(r'\$(.*?)\$', final_answer)) > 0:
final_answer = re.findall(r'\$(.*?)\$', final_answer)[-1]
final_answer = final_answer.strip()
if 'rac' in final_answer and '\\frac' not in final_answer:
final_answer = final_answer.replace('rac', '\\frac')
# Normalize shorthand TeX:
# \fracab -> \frac{a}{b}
# \frac{abc}{bef} -> \frac{abc}{bef}
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}',
final_answer)
final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
final_answer = final_answer.replace('$', '')
# Normalize 100,000 -> 100000
if final_answer.replace(',', '').isdigit():
final_answer = final_answer.replace(',', '')
return final_answer
for maybe_ans in text.split('.'):
if 'final answer' in maybe_ans.lower():
return normalize_final_answer(maybe_ans)
return normalize_final_answer(text.split('.')[0])
# return normalize_final_answer(
# text.split('Final Answer: ', 1)[-1].split('\n\n')[0])
@ICL_EVALUATORS.register_module()
class MATHEvaluator(BaseEvaluator):
def score(self, predictions, references):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
correct = 0
count = 0
for i, j in zip(predictions, references):
count += 1
if self.is_equiv(i, j):
correct += 1
result = {'accuracy': 100 * correct / count}
return result
def _fix_fracs(self, string):
substrs = string.split('\\frac')
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if substr[0] == '{':
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != '{':
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}{' + b + '}' + post_substr
else:
new_str += '{' + a + '}{' + b + '}'
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}' + b + post_substr
else:
new_str += '{' + a + '}' + b
string = new_str
return string
def _fix_a_slash_b(self, string):
if len(string.split('/')) != 2:
return string
a = string.split('/')[0]
b = string.split('/')[1]
try:
a = int(a)
b = int(b)
assert string == '{}/{}'.format(a, b)
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string
except AssertionError:
return string
def _remove_right_units(self, string):
# "\\text{ " only ever occurs (at least in the val set) when describing
# units
if '\\text{ ' in string:
splits = string.split('\\text{ ')
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(self, string):
if '\\sqrt' not in string:
return string
splits = string.split('\\sqrt')
new_string = splits[0]
for split in splits[1:]:
if split[0] != '{':
a = split[0]
new_substr = '\\sqrt{' + a + '}' + split[1:]
else:
new_substr = '\\sqrt' + split
new_string += new_substr
return new_string
def _strip_string(self, string):
# linebreaks
string = string.replace('\n', '')
# remove inverse spaces
string = string.replace('\\!', '')
# replace \\ with \
string = string.replace('\\\\', '\\')
# replace tfrac and dfrac with frac
string = string.replace('tfrac', 'frac')
string = string.replace('dfrac', 'frac')
# remove \left and \right
string = string.replace('\\left', '')
string = string.replace('\\right', '')
# Remove circ (degrees)
string = string.replace('^{\\circ}', '')
string = string.replace('^\\circ', '')
# remove dollar signs
string = string.replace('\\$', '')
# remove units (on the right)
string = self._remove_right_units(string)
# remove percentage
string = string.replace('\\%', '')
string = string.replace('\%', '') # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
# add "0" if "." is the start of the string
string = string.replace(' .', ' 0.')
string = string.replace('{.', '{0.')
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == '.':
string = '0' + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split('=')) == 2:
if len(string.split('=')[0]) <= 2:
string = string.split('=')[1]
# fix sqrt3 --> sqrt{3}
string = self._fix_sqrt(string)
# remove spaces
string = string.replace(' ', '')
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
# with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = self._fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == '0.5':
string = '\\frac{1}{2}'
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
# in case the model output is X/Y
string = self._fix_a_slash_b(string)
return string
def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None:
print('WARNING: Both None')
return True
if str1 is None or str2 is None:
return False
try:
ss1 = self._strip_string(str1)
ss2 = self._strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except: # noqa
return str1 == str2