"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "beb932c5d111872c5e45387e7b1b2b3dd0524a47"
Unverified Commit 8d9cee06 authored by Hubert's avatar Hubert Committed by GitHub
Browse files

[Feat] update postprocessor to get first option more accurately (#193)

* [Feat] update postprocessor to get first option

* minor fix

* minor fix
parent 14332e08
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ARCDataset from opencompass.datasets import ARCDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
ARC_c_reader_cfg = dict( ARC_c_reader_cfg = dict(
input_columns=["question", "textA", "textB", "textC", "textD"], input_columns=["question", "textA", "textB", "textC", "textD"],
...@@ -28,7 +28,7 @@ ARC_c_infer_cfg = dict( ...@@ -28,7 +28,7 @@ ARC_c_infer_cfg = dict(
ARC_c_eval_cfg = dict( ARC_c_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
ARC_c_datasets = [ ARC_c_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ARCDataset from opencompass.datasets import ARCDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
ARC_e_reader_cfg = dict( ARC_e_reader_cfg = dict(
input_columns=["question", "textA", "textB", "textC", "textD"], input_columns=["question", "textA", "textB", "textC", "textD"],
...@@ -28,7 +28,7 @@ ARC_e_infer_cfg = dict( ...@@ -28,7 +28,7 @@ ARC_e_infer_cfg = dict(
ARC_e_eval_cfg = dict( ARC_e_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
ARC_e_datasets = [ ARC_e_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AXDataset_V2 from opencompass.datasets import AXDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
AX_b_reader_cfg = dict( AX_b_reader_cfg = dict(
input_columns=["sentence1", "sentence2"], input_columns=["sentence1", "sentence2"],
...@@ -28,7 +28,7 @@ AX_b_infer_cfg = dict( ...@@ -28,7 +28,7 @@ AX_b_infer_cfg = dict(
AX_b_eval_cfg = dict( AX_b_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
AX_b_datasets = [ AX_b_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AXDataset_V2 from opencompass.datasets import AXDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
AX_g_reader_cfg = dict( AX_g_reader_cfg = dict(
input_columns=["hypothesis", "premise"], input_columns=["hypothesis", "premise"],
...@@ -28,7 +28,7 @@ AX_g_infer_cfg = dict( ...@@ -28,7 +28,7 @@ AX_g_infer_cfg = dict(
AX_g_eval_cfg = dict( AX_g_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
AX_g_datasets = [ AX_g_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CBDataset_V2 from opencompass.datasets import CBDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
CB_reader_cfg = dict( CB_reader_cfg = dict(
input_columns=["premise", "hypothesis"], input_columns=["premise", "hypothesis"],
...@@ -29,7 +29,7 @@ CB_infer_cfg = dict( ...@@ -29,7 +29,7 @@ CB_infer_cfg = dict(
CB_eval_cfg = dict( CB_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABC'),
) )
CB_datasets = [ CB_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import COPADataset_V2 from opencompass.datasets import COPADataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
COPA_reader_cfg = dict( COPA_reader_cfg = dict(
input_columns=["question", "premise", "choice1", "choice2"], input_columns=["question", "premise", "choice1", "choice2"],
...@@ -29,7 +29,7 @@ COPA_infer_cfg = dict( ...@@ -29,7 +29,7 @@ COPA_infer_cfg = dict(
COPA_eval_cfg = dict( COPA_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
COPA_datasets = [ COPA_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import MultiRCDataset_V2 from opencompass.datasets import MultiRCDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
MultiRC_reader_cfg = dict( MultiRC_reader_cfg = dict(
input_columns=["question", "text", "answer"], input_columns=["question", "text", "answer"],
...@@ -28,7 +28,7 @@ MultiRC_infer_cfg = dict( ...@@ -28,7 +28,7 @@ MultiRC_infer_cfg = dict(
MultiRC_eval_cfg = dict( MultiRC_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
MultiRC_datasets = [ MultiRC_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AXDataset_V2 from opencompass.datasets import AXDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
RTE_reader_cfg = dict( RTE_reader_cfg = dict(
input_columns=["hypothesis", "premise"], input_columns=["hypothesis", "premise"],
...@@ -28,7 +28,7 @@ RTE_infer_cfg = dict( ...@@ -28,7 +28,7 @@ RTE_infer_cfg = dict(
RTE_eval_cfg = dict( RTE_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
RTE_datasets = [ RTE_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AGIEvalDataset_v2, AGIEvalEvaluator from opencompass.datasets import AGIEvalDataset_v2, AGIEvalEvaluator
from opencompass.utils.text_postprocessors import first_capital_postprocess, first_capital_postprocess_multi from opencompass.utils.text_postprocessors import first_option_postprocess, first_capital_postprocess_multi
agieval_reader_cfg = dict( agieval_reader_cfg = dict(
input_columns=['question', 'options'], output_column='label') input_columns=['question', 'options'], output_column='label')
...@@ -76,14 +76,16 @@ for _name in agieval_single_choice_sets: ...@@ -76,14 +76,16 @@ for _name in agieval_single_choice_sets:
prompt_template=dict( prompt_template=dict(
type=PromptTemplate, type=PromptTemplate,
template=dict(round=[ template=dict(round=[
dict(role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}') dict(
role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}')
])), ])),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024)) inferencer=dict(type=GenInferencer, max_out_len=1024))
agieval_eval_cfg = dict( agieval_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_postprocessor=dict(type=first_capital_postprocess)) pred_postprocessor=dict(
type=first_option_postprocess, options='ABCDE'))
agieval_datasets.append( agieval_datasets.append(
dict( dict(
...@@ -105,7 +107,8 @@ for _name in agieval_multiple_choices_sets: ...@@ -105,7 +107,8 @@ for _name in agieval_multiple_choices_sets:
prompt_template=dict( prompt_template=dict(
type=PromptTemplate, type=PromptTemplate,
template=dict(round=[ template=dict(round=[
dict(role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}') dict(
role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}')
])), ])),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024)) inferencer=dict(type=GenInferencer, max_out_len=1024))
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import hellaswagDataset_V2 from opencompass.datasets import hellaswagDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
hellaswag_reader_cfg = dict( hellaswag_reader_cfg = dict(
input_columns=["ctx", "A", "B", "C", "D"], input_columns=["ctx", "A", "B", "C", "D"],
...@@ -16,8 +16,7 @@ hellaswag_infer_cfg = dict( ...@@ -16,8 +16,7 @@ hellaswag_infer_cfg = dict(
template=dict(round=[ template=dict(round=[
dict( dict(
role="HUMAN", role="HUMAN",
prompt=( prompt=("{ctx}\nQuestion: Which ending makes the most sense?\n"
"{ctx}\nQuestion: Which ending makes the most sense?\n"
"A. {A}\nB. {B}\nC. {C}\nD. {D}\n" "A. {A}\nB. {B}\nC. {C}\nD. {D}\n"
"You may choose from 'A', 'B', 'C', 'D'.\n" "You may choose from 'A', 'B', 'C', 'D'.\n"
"Answer:"), "Answer:"),
...@@ -31,7 +30,7 @@ hellaswag_infer_cfg = dict( ...@@ -31,7 +30,7 @@ hellaswag_infer_cfg = dict(
hellaswag_eval_cfg = dict( hellaswag_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
hellaswag_datasets = [ hellaswag_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import OBQADataset from opencompass.datasets import OBQADataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
_input_columns = [ _input_columns = [
["question_stem", "A", "B", "C", "D"], ["question_stem", "A", "B", "C", "D"],
...@@ -14,14 +14,16 @@ _template = [ ...@@ -14,14 +14,16 @@ _template = [
round=[ round=[
dict( dict(
role="HUMAN", role="HUMAN",
prompt="Question: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:" prompt=
"Question: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:"
), ),
], ), ], ),
dict( dict(
round=[ round=[
dict( dict(
role="HUMAN", role="HUMAN",
prompt="Given the fact: {fact1}\nQuestion: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:", prompt=
"Given the fact: {fact1}\nQuestion: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:",
), ),
], ), ], ),
] ]
...@@ -46,16 +48,14 @@ for _i in range(2): ...@@ -46,16 +48,14 @@ for _i in range(2):
obqa_reader_cfg = dict( obqa_reader_cfg = dict(
input_columns=_input_columns[_i], output_column="answerKey") input_columns=_input_columns[_i], output_column="answerKey")
obqa_infer_cfg = dict( obqa_infer_cfg = dict(
prompt_template=dict( prompt_template=dict(type=PromptTemplate, template=_template[_i]),
type=PromptTemplate,
template=_template[_i]),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer), inferencer=dict(type=GenInferencer),
) )
obqa_eval_cfg = dict( obqa_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
obqa_datasets[_i]["reader_cfg"] = obqa_reader_cfg obqa_datasets[_i]["reader_cfg"] = obqa_reader_cfg
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import piqaDataset_V2 from opencompass.datasets import piqaDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
piqa_reader_cfg = dict( piqa_reader_cfg = dict(
input_columns=["goal", "sol1", "sol2"], input_columns=["goal", "sol1", "sol2"],
...@@ -15,7 +15,9 @@ piqa_infer_cfg = dict( ...@@ -15,7 +15,9 @@ piqa_infer_cfg = dict(
type=PromptTemplate, type=PromptTemplate,
template=dict( template=dict(
round=[ round=[
dict(role="HUMAN", prompt="{goal}\nA. {sol1}\nB. {sol2}\nAnswer:") dict(
role="HUMAN",
prompt="{goal}\nA. {sol1}\nB. {sol2}\nAnswer:")
], ), ], ),
), ),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
...@@ -25,7 +27,7 @@ piqa_infer_cfg = dict( ...@@ -25,7 +27,7 @@ piqa_infer_cfg = dict(
piqa_eval_cfg = dict( piqa_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
piqa_datasets = [ piqa_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import RaceDataset from opencompass.datasets import RaceDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
race_reader_cfg = dict( race_reader_cfg = dict(
input_columns=['article', 'question', 'A', 'B', 'C', 'D'], input_columns=['article', 'question', 'A', 'B', 'C', 'D'],
...@@ -24,7 +24,7 @@ race_infer_cfg = dict( ...@@ -24,7 +24,7 @@ race_infer_cfg = dict(
race_eval_cfg = dict( race_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
pred_role='BOT') pred_role='BOT')
race_datasets = [ race_datasets = [
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import storyclozeDataset_V2 from opencompass.datasets import storyclozeDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
storycloze_reader_cfg = dict( storycloze_reader_cfg = dict(
input_columns=["context", "sentence_quiz1", "sentence_quiz2"], input_columns=["context", "sentence_quiz1", "sentence_quiz2"],
...@@ -28,7 +28,7 @@ storycloze_infer_cfg = dict( ...@@ -28,7 +28,7 @@ storycloze_infer_cfg = dict(
storycloze_eval_cfg = dict( storycloze_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
# The original story cloze dataset and repo are no longer maintained. # The original story cloze dataset and repo are no longer maintained.
......
...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever ...@@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import winograndeDataset_V2 from opencompass.datasets import winograndeDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
winogrande_reader_cfg = dict( winogrande_reader_cfg = dict(
input_columns=["opt1", "opt2"], input_columns=["opt1", "opt2"],
...@@ -28,7 +28,7 @@ winogrande_infer_cfg = dict( ...@@ -28,7 +28,7 @@ winogrande_infer_cfg = dict(
winogrande_eval_cfg = dict( winogrande_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
winogrande_datasets = [ winogrande_datasets = [
......
...@@ -48,6 +48,31 @@ def first_capital_postprocess(text: str) -> str: ...@@ -48,6 +48,31 @@ def first_capital_postprocess(text: str) -> str:
return '' return ''
def first_option_postprocess(text: str, options: str) -> str:
    """Find the first valid option in a model response.

    The response is checked against a list of patterns, ordered from the
    most explicit answer statement to the loosest fallback, and the option
    letter captured by the first pattern that matches is returned.

    Args:
        text (str): The raw prediction text.
        options (str): The candidate option characters, e.g. ``'ABCD'``.

    Returns:
        str: The matched option character, or ``''`` when ``options`` is
        empty or no option can be found in ``text``.
    """
    if not options:
        # An empty character class ``[]`` is not a valid regex.
        return ''

    # Escape the options so characters such as ``-``, ``^`` or ``]`` are
    # treated literally inside the character class.
    option = rf'([{re.escape(options)}])'

    # Each pattern captures the option itself in group 1, so the result
    # never depends on other characters of the matched phrase (e.g. the
    # "a" of "answer" when the options are lowercase).
    patterns = [
        rf'[Tt]he answer is {option}',
        rf'[Tt]he correct answer is {option}',
        rf'答案是.*?{option}',
        rf'答案为.*?{option}',
        # "故选" (therefore choose); the misspelt "固选" is still accepted.
        rf'[故固]选.*?{option}',
        rf'答案应该是.*?{option}',
        # A standalone option: preceded by whitespace or the start of the
        # text and followed by a delimiter or the end of the text. ``$``
        # inside the class is a literal dollar sign (kept for backward
        # compatibility); the alternative ``$`` is the real end anchor.
        rf'(?:\s|^){option}(?:[\s。,,\.$]|$)',
        # Last resort: the first option character anywhere in the text.
        option,
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1)
    return ''
@TEXT_POSTPROCESSORS.register_module('first-capital-multi') @TEXT_POSTPROCESSORS.register_module('first-capital-multi')
def first_capital_postprocess_multi(text: str) -> str: def first_capital_postprocess_multi(text: str) -> str:
match = re.search(r'([A-D]+)', text) match = re.search(r'([A-D]+)', text)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment