Unverified Commit ad872a5d authored by Xiaoming Shi's avatar Xiaoming Shi Committed by GitHub
Browse files

[Feature] Update MedBench (#779)



* update medbench

* medbench update

* format medbench

* format

* Update

* update

* update

* update suffix

---------
Co-authored-by: default avatar施晓明 <PJLAB\shixiaoming@pjnl104220118l.pjlab.org>
Co-authored-by: default avatarLeymore <zfz-960727@163.com>
parent a74e4c1a
...@@ -6,7 +6,7 @@ exclude: | ...@@ -6,7 +6,7 @@ exclude: |
opencompass/openicl/icl_evaluator/hf_metrics/| opencompass/openicl/icl_evaluator/hf_metrics/|
opencompass/datasets/lawbench/utils| opencompass/datasets/lawbench/utils|
opencompass/datasets/lawbench/evaluation_functions/| opencompass/datasets/lawbench/evaluation_functions/|
opencompass/datasets/medbench| opencompass/datasets/medbench/|
docs/zh_cn/advanced_guides/compassbench_intro.md docs/zh_cn/advanced_guides/compassbench_intro.md
) )
repos: repos:
......
...@@ -2,41 +2,24 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate ...@@ -2,41 +2,24 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ( from opencompass.datasets import MedBenchDataset, MedBenchEvaluator, MedBenchEvaluator_Cloze, MedBenchEvaluator_IE, MedBenchEvaluator_mcq, MedBenchEvaluator_CMeEE, MedBenchEvaluator_CMeIE, MedBenchEvaluator_CHIP_CDEE, MedBenchEvaluator_CHIP_CDN, MedBenchEvaluator_CHIP_CTC, MedBenchEvaluator_NLG, MedBenchEvaluator_TF, MedBenchEvaluator_DBMHG, MedBenchEvaluator_SMDoc, MedBenchEvaluator_IMCS_V2_MRG
MedBenchDataset,
MedBenchEvaluator,
MedBenchEvaluator_Cloze,
MedBenchEvaluator_IE,
MedBenchEvaluator_mcq,
MedBenchEvaluator_CMeEE,
MedBenchEvaluator_CMeIE,
MedBenchEvaluator_CHIP_CDEE,
MedBenchEvaluator_CHIP_CDN,
MedBenchEvaluator_CHIP_CTC,
MedBenchEvaluator_NLG,
MedBenchEvaluator_TF,
MedBenchEvaluator_EMR,
)
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_capital_postprocess
medbench_reader_cfg = dict( medbench_reader_cfg = dict(
input_columns=['problem_input'], output_column='label') input_columns=['problem_input'], output_column='label')
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断 medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'SafetyBench'] # 选择题,用acc判断
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答 medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答
medbench_cloze_sets = ['Triage'] # 限定域QA,有标答 medbench_cloze_sets = ['MedHG'] # 限定域QA,有标答
medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答 medbench_single_choice_sets = ['DrugCA'] # 正确与否判断,有标答
medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价 medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CTC', 'SMDoc', 'IMCS-V2-MRG'] # 判断识别的实体是否一致,用F1评价
#, 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG'
medbench_datasets = [] medbench_datasets = []
for name in medbench_single_choice_sets: for name in medbench_single_choice_sets:
medbench_infer_cfg = dict( medbench_infer_cfg = dict(
prompt_template=dict( prompt_template=dict(
...@@ -144,7 +127,7 @@ for name in medbench_ie_sets: ...@@ -144,7 +127,7 @@ for name in medbench_ie_sets:
inferencer=dict(type=GenInferencer)) inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict( medbench_eval_cfg = dict(
evaluator=dict(type=eval('MedBenchEvaluator_'+name)), pred_role="BOT") evaluator=dict(type=eval('MedBenchEvaluator_'+name.replace('-', '_'))), pred_role="BOT")
medbench_datasets.append( medbench_datasets.append(
dict( dict(
......
...@@ -11,31 +11,31 @@ from .constructions import ChatGPTSchema, ResultsForHumanSchema ...@@ -11,31 +11,31 @@ from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl from .utils import extract_answer, read_jsonl, save_jsonl
# define the datasets # define the datasets
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断 medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'DDx-advanced', 'SafetyBench'] # 选择题,用acc判断
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答 medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答
medbench_cloze_sets = ['Triage'] # 限定域QA,有标答 medbench_cloze_sets = ['MedHG'] # 限定域QA,有标答
medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答 medbench_single_choice_sets = ['DrugCA'] # 正确与否判断,有标答
medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价 medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CTC', 'SMDoc', 'IMCS-V2-MRG'] # 判断识别的实体是否一致,用F1评价
def convert_zero_shot(line, dataset_name): def convert_zero_shot(line, dataset_name):
# passage = line['passage'] if line['passage'] is not None else '' # passage = line['passage'] if line['passage'] is not None else ''
if dataset_name in medbench_qa_sets: # if dataset_name in medbench_qa_sets:
return line['question'] # return line['question']
elif dataset_name in medbench_cloze_sets: # elif dataset_name in medbench_cloze_sets:
return '问题:' + line['question'] + '\n答案:' # return '问题:' + line['question'] + '\n答案:'
elif dataset_name in medbench_multiple_choices_sets: # elif dataset_name in medbench_multiple_choices_sets:
return '问题:' + line['question'] + ' ' \ # return '问题:' + line['question'] + ' ' \
+ '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择' # + '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择'
else: # else:
# return line['question']
return line['question'] return line['question']
prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n' prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'
# def convert_zero_shot_CoT_stage1(line, dataset_name): # def convert_zero_shot_CoT_stage1(line, dataset_name):
# try: # try:
# passage = line['passage'] if line['passage'] is not None else '' # passage = line['passage'] if line['passage'] is not None else ''
......
This diff is collapsed.
...@@ -82,10 +82,8 @@ class MedBenchEvaluator(BaseEvaluator): ...@@ -82,10 +82,8 @@ class MedBenchEvaluator(BaseEvaluator):
detail['correct'] = True detail['correct'] = True
details.append(detail) details.append(detail)
score = cnt / len(predictions) * 100 score = cnt / len(predictions) * 100
#输出字典类型 {'score':'', 'details'}
return {'Accuracy': score, 'details': details} return {'Accuracy': score, 'details': details}
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MedBenchEvaluator_mcq(BaseEvaluator): class MedBenchEvaluator_mcq(BaseEvaluator):
...@@ -109,16 +107,18 @@ class MedBenchEvaluator_mcq(BaseEvaluator): ...@@ -109,16 +107,18 @@ class MedBenchEvaluator_mcq(BaseEvaluator):
return {'score': score, 'details': details} return {'score': score, 'details': details}
def process_generated_results_CMeEE(pred_file): def process_generated_results_CMeEE(pred_file):
# 实体每类占一行,每行格式为 "[类型名称]实体:实体名称1,实体名称2,实体名称3\n"
# 多个实体,用 ,符号分割
structured_output = [] structured_output = []
answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病'] answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病']
for pred in pred_file: for pred in pred_file:
list_entities = [] list_entities = []
for choice in answer_choices: for choice in answer_choices:
for piece in re.split('[,|.|。|;|\n]', pred): for piece in re.split('[。|;|\n]', pred):
if piece.startswith(f"{choice}"): if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").split(",") mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").replace(f'{choice}:', '').replace(f'{choice}:', '').split(",")
for ment in mentions: for ment in mentions:
list_entities.append({'entity':ment, 'type':choice}) list_entities.append({'type':choice, 'entity':ment})
structured_output.append(list_entities) structured_output.append(list_entities)
return structured_output return structured_output
...@@ -128,12 +128,15 @@ def process_generated_results_EMR(pred_file): ...@@ -128,12 +128,15 @@ def process_generated_results_EMR(pred_file):
for pred in pred_file: for pred in pred_file:
list_entities = [] list_entities = []
for choice in answer_choices: for choice in answer_choices:
for piece in re.split('[,|.|?|;|,|。|;|\n]', pred): for piece in re.split('\n', pred):
if piece.startswith(f"{choice}"): # if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}:", "").split(",") if f"{choice}" in piece and len(piece.split(f"{choice}:"))>1:
mentions = [w.strip() for w in mentions if len(w.strip()) > 0] # mentions = piece.replace(f"{choice}:", "").split(",")
for ment in mentions: mentions = piece.split(f"{choice}:")[1].strip()
list_entities.append({ment: choice}) # mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
list_entities.append({choice: mentions})
# for ment in mentions:
# list_entities.append({choice: ment})
structured_output.append(list_entities) structured_output.append(list_entities)
return structured_output return structured_output
...@@ -156,13 +159,16 @@ def process_generated_results_CMeIE(pred_file): ...@@ -156,13 +159,16 @@ def process_generated_results_CMeIE(pred_file):
# 首先是解析出label: # 首先是解析出label:
predicate = line.split("关系的头尾实体对")[0][2: ].strip() predicate = line.split("关系的头尾实体对")[0][2: ].strip()
line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "") line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "")
for spo_str in line.split("。"): # for spo_str in line.split("。"):
if len(spo_str.split(",尾实体为")) < 2: for spo_str in re.split(';|。', line):
if len(spo_str.split(",尾实体:")) < 2:
continue continue
head_mention_str, tail_mention_str = spo_str.split(",尾实体为")[:2] head_mention_str, tail_mention_str = spo_str.split(",尾实体:")[:2]
head_mention_str = head_mention_str.replace("头实体为", "").strip()
tail_mention_str = tail_mention_str.replace("尾实体为", "").strip() head_mention_str = head_mention_str.replace("头实体:", "").strip()
tail_mention_str = tail_mention_str.replace("尾实体:", "").strip()
list_spos.append( list_spos.append(
{ {
...@@ -176,7 +182,7 @@ def process_generated_results_CMeIE(pred_file): ...@@ -176,7 +182,7 @@ def process_generated_results_CMeIE(pred_file):
def process_generated_results_CDN(pred_file): def process_generated_results_CDN(pred_file):
structured_output = [] structured_output = []
answer_choices = json.load(open('./data/MedBench/CHIP_CDN/CHIP-CDN_entity.json', 'r')) answer_choices = json.load(open('./opencompass/datasets/medbench/entity_list.jsonl', 'r'))
for line in pred_file: for line in pred_file:
gen_output = line gen_output = line
...@@ -211,15 +217,17 @@ def process_generated_results_CDEE(pred_file): ...@@ -211,15 +217,17 @@ def process_generated_results_CDEE(pred_file):
keys = ["主体词", "发生状态", "描述词", "解剖部位"] keys = ["主体词", "发生状态", "描述词", "解剖部位"]
list_answer_strs = gen_output.split("\n") list_answer_strs = gen_output.split("\n")
# list_answer_strs: ['主题词:饮食,描述词:差;', '主题词:消瘦']
list_events = [] list_events = []
for ans_str in list_answer_strs: for ans_str in list_answer_strs:
if '主体词' in ans_str: if '主体词' in ans_str:
event_info = {} event_info = {}
ans_attrs = ans_str.split(";") ans_attrs = ans_str.split(",")
for a_attr in ans_attrs: for a_attr in ans_attrs:
for key in keys: for key in keys:
if a_attr.startswith(f"{key}:"): if a_attr.startswith(f"{key}:"):
a_attr = a_attr.replace(f"{key}:", "").strip() a_attr = a_attr.replace(f"{key}:", "").strip().strip(';')
if key in ["描述词", "解剖部位"]: if key in ["描述词", "解剖部位"]:
a_attr_split = a_attr.split(",") a_attr_split = a_attr.split(",")
a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0] a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0]
...@@ -239,7 +247,7 @@ def process_generated_results_CDEE(pred_file): ...@@ -239,7 +247,7 @@ def process_generated_results_CDEE(pred_file):
structured_output.append(list_events) structured_output.append(list_events)
return structured_output return structured_output
def process_generated_results_CTC(pred_file, task_dataset): def process_generated_results_CTC(pred_file):
structured_output = [] structured_output = []
for line in pred_file: for line in pred_file:
...@@ -252,60 +260,60 @@ def process_generated_results_CTC(pred_file, task_dataset): ...@@ -252,60 +260,60 @@ def process_generated_results_CTC(pred_file, task_dataset):
def process_generated_results_doc_parsing(pred_file): def process_generated_results_doc_parsing(pred_file):
output = [] output = []
for line in pred_file: for line in pred_file:
structured_output = {'体温':'', '脉搏':'', '心率':'', '收缩压':'', '舒张压':'', '呼吸':'', '上腹部深压痛':'', '腹部反跳痛':'', '上腹部肿块':''} structured_output = []
sentence_list = line.strip().split(',|。|\n') sentence_list = line.strip().split('\n')
for sentence in sentence_list: for sentence in sentence_list:
if '体温' in sentence: if '体温' in sentence:
temp_value = re.search('[0-9]+', sentence) temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value: if temp_value:
structured_output['体温'] = temp_value.group(0) structured_output.append({'type':'体温', 'entity':temp_value.group(0)})
else: else:
structured_output['体温'] = '未扪及' structured_output.append({'type':'体温', 'entity':'未扪及'})
elif '脉搏' in sentence: elif '脉搏' in sentence:
temp_value = re.search('[0-9]+', sentence) temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value: if temp_value:
structured_output['脉搏'] = temp_value.group(0) structured_output.append({'type':'脉搏', 'entity':temp_value.group(0)})
else: else:
structured_output['脉搏'] = '未扪及' structured_output.append({'type':'脉搏', 'entity':'未扪及'})
elif '心率' in sentence: elif '心率' in sentence:
temp_value = re.search('[0-9]+', sentence) temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value: if temp_value:
structured_output['心率'] = temp_value.group(0) structured_output.append({'type':'心率', 'entity':temp_value.group(0)})
else: else:
structured_output['心率'] = '未扪及' structured_output.append({'type':'心率', 'entity':'未扪及'})
elif '收缩压' in sentence: elif '收缩压' in sentence:
temp_value = re.search('[0-9]+', sentence) temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value: if temp_value:
structured_output['收缩压'] = temp_value.group(0) structured_output.append({'type':'收缩压', 'entity':temp_value.group(0)})
else: else:
structured_output['收缩压'] = '未扪及' structured_output.append({'type':'收缩压', 'entity':'未扪及'})
elif '舒张压' in sentence: elif '舒张压' in sentence:
temp_value = re.search('[0-9]+', sentence) temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value: if temp_value:
structured_output['舒张压'] = temp_value.group(0) structured_output.append({'type':'舒张压', 'entity':temp_value.group(0)})
else: else:
structured_output['舒张压'] = '未扪及' structured_output.append({'type':'舒张压', 'entity':'未扪及'})
elif '呼吸' in sentence: elif '呼吸' in sentence:
temp_value = re.search('[0-9]+', sentence) temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value: if temp_value:
structured_output['呼吸'] = temp_value.group(0) structured_output.append({'type':'呼吸', 'entity':temp_value.group(0)})
else: else:
structured_output['呼吸'] = '未扪及' structured_output.append({'type':'呼吸', 'entity':'未扪及'})
elif '上腹部深压痛' in sentence: elif '上腹部深压痛' in sentence:
if re.search('是|存在|有', sentence): if re.search('未|不|没|无', sentence):
structured_output['是否上腹部深压痛'] = '是' structured_output.append({'type':'上腹部深压痛', 'entity':'否是'})
else: else:
structured_output['是否上腹部深压痛'] = '否' structured_output.append({'type':'上腹部深压痛', 'entity':'是'})
elif '腹部反跳痛' in sentence: elif '腹部反跳痛' in sentence:
if re.search('是|存在|有', sentence): if re.search('未|不|没|无', sentence):
structured_output['是否腹部反跳痛'] = '是' structured_output.append({'type':'腹部反跳痛', 'entity':'否'})
else: else:
structured_output['是否腹部反跳痛'] = '否' structured_output.append({'type':'腹部反跳痛', 'entity':'是'})
elif '上腹部肿块' in sentence: elif '上腹部肿块' in sentence:
if re.search('是|存在|有', sentence): if re.search('未|不|没|无', sentence):
structured_output['上腹部肿块'] = '扪及' structured_output.append({'type':'上腹部肿块', 'entity':'未扪及'})
else: else:
structured_output['上腹部肿块'] = '未扪及' structured_output.append({'type':'上腹部肿块', 'entity':'扪及'})
output.append(structured_output) output.append(structured_output)
return output return output
...@@ -315,18 +323,22 @@ def process_generated_results_mrg(pred_file): ...@@ -315,18 +323,22 @@ def process_generated_results_mrg(pred_file):
for pred in pred_file: for pred in pred_file:
list_entities = [] list_entities = []
for choice in answer_choices: for choice in answer_choices:
for piece in re.split('[,|.|?|;|,|。|;|\n]', pred): if '\n\n' in pred['answer']:
if piece.startswith(f"{choice}实体"): for piece in re.split('\n\n', pred['answer']):
mentions = piece.replace(f"{choice}实体:", "").split(",") if f"{choice}" in piece and len(piece.split(f"{choice}:"))>1:
mentions = piece.split(f"{choice}:")[1].strip()
list_entities.append({choice:mentions})
else:
for piece in re.split('\n', pred):
if piece.startswith(f"{choice}:"):
mentions = piece.replace(f"{choice}:", "").split(",")
mentions = [w.strip() for w in mentions if len(w.strip()) > 0] mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
for ment in mentions: for ment in mentions:
list_entities.append({ment: choice}) list_entities.append({choice:ment})
structured_output.append(list_entities) structured_output.append(list_entities)
return structured_output return structured_output
def calc_info_extract_task_scores(list_structured_predict, list_structured_golden):
def calc_info_extract_task_scores(list_structured_golden,
list_structured_predict):
assert len(list_structured_golden) == len(list_structured_predict) assert len(list_structured_golden) == len(list_structured_predict)
...@@ -334,12 +346,11 @@ def calc_info_extract_task_scores(list_structured_golden, ...@@ -334,12 +346,11 @@ def calc_info_extract_task_scores(list_structured_golden,
fp = 0 fp = 0
fn = 0 fn = 0
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
# samp_golden: [[{}]]
answer_golden = samp_golden answer_golden = samp_golden
answer_predict = samp_predict answer_predict = samp_predict
# assert isinstance(answer_golden, list)
assert isinstance(answer_golden, list) # assert isinstance(answer_predict, list), "sample format is wrong!"
assert isinstance(answer_predict, list), "sample format is wrong!"
set_golden = set() set_golden = set()
for inst in answer_golden: for inst in answer_golden:
...@@ -356,18 +367,11 @@ def calc_info_extract_task_scores(list_structured_golden, ...@@ -356,18 +367,11 @@ def calc_info_extract_task_scores(list_structured_golden,
for inst in answer_predict: for inst in answer_predict:
assert isinstance(inst, dict) assert isinstance(inst, dict)
keys = sorted(list(inst.keys())) keys = sorted(list(inst.keys()))
# inst = tuple([inst[w] for w in keys])
inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])
# inst = list(inst.items()) inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])
# inst.sort()
# inst = tuple(inst)
set_predict.add(inst) set_predict.add(inst)
# print("set_predict: ", set_predict)
# print("set_golden: ", set_golden)
tp += len(set_golden.intersection(set_predict)) tp += len(set_golden.intersection(set_predict))
fp += len(set_predict.difference(set_golden)) fp += len(set_predict.difference(set_golden))
fn += len(set_golden.difference(set_predict)) fn += len(set_golden.difference(set_predict))
...@@ -402,7 +406,9 @@ def calc_cls_task_scores(list_structured_golden, ...@@ -402,7 +406,9 @@ def calc_cls_task_scores(list_structured_golden,
pred_label = pred_samp pred_label = pred_samp
gt_label = gt_samp gt_label = gt_samp
assert gt_label != "" # assert gt_label != ""
if gt_label == "":
get_label = list_labels[0]
if pred_label == "": if pred_label == "":
pred_label = list_labels[0] pred_label = list_labels[0]
...@@ -434,16 +440,10 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict): ...@@ -434,16 +440,10 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
references = [] references = []
details = [] details = []
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
# print("samp_golden: ", samp_golden)
# print("samp_predict: ", samp_predict)
# assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!"
answer_golden = samp_golden answer_golden = samp_golden
answer_predict = samp_predict answer_predict = samp_predict
print('#')
print(answer_golden)
print(answer_predict)
if not (answer_predict and answer_golden): if not (answer_predict and answer_golden):
continue continue
...@@ -456,8 +456,6 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict): ...@@ -456,8 +456,6 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
answer_golden = "无 。" answer_golden = "无 。"
if answer_predict.strip() == "": if answer_predict.strip() == "":
answer_predict = "无 。" answer_predict = "无 。"
# print("answer_predict: ", answer_predict)
# print("answer_golden: ", answer_golden)
predictions.append(answer_predict) predictions.append(answer_predict)
references.append(answer_golden) references.append(answer_golden)
...@@ -542,14 +540,14 @@ class MedBenchEvaluator_CMeEE(BaseEvaluator): ...@@ -542,14 +540,14 @@ class MedBenchEvaluator_CMeEE(BaseEvaluator):
return calc_scores_f1(predictions, references) return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MedBenchEvaluator_EMR(BaseEvaluator): class MedBenchEvaluator_DBMHG(BaseEvaluator):
def score(self, predictions, references): def score(self, predictions, references):
predictions = process_generated_results_EMR(predictions) predictions = process_generated_results_EMR(predictions)
return calc_scores_f1(predictions, references) return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MedBenchEvaluator_MRG(BaseEvaluator): class MedBenchEvaluator_IMCS_V2_MRG(BaseEvaluator):
def score(self, predictions, references): def score(self, predictions, references):
predictions = process_generated_results_mrg(predictions) predictions = process_generated_results_mrg(predictions)
...@@ -581,10 +579,10 @@ class MedBenchEvaluator_CHIP_CTC(BaseEvaluator): ...@@ -581,10 +579,10 @@ class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):
def score(self, predictions, references): def score(self, predictions, references):
predictions = process_generated_results_CTC(predictions) predictions = process_generated_results_CTC(predictions)
return calc_scores_ctc(predictions, references)[0] return calc_scores_ctc(predictions, references)
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Doc_parsing(BaseEvaluator): class MedBenchEvaluator_SMDoc(BaseEvaluator):
def score(self, predictions, references): def score(self, predictions, references):
predictions = process_generated_results_doc_parsing(predictions) predictions = process_generated_results_doc_parsing(predictions)
...@@ -594,23 +592,36 @@ class MedBenchEvaluator_Doc_parsing(BaseEvaluator): ...@@ -594,23 +592,36 @@ class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
class MedBenchEvaluator_NLG(BaseEvaluator): class MedBenchEvaluator_NLG(BaseEvaluator):
def score(self, predictions, references): def score(self, predictions, references):
# predictions = process_generated_results_med(predictions)
return calc_scores_nlg(predictions, references) return calc_scores_nlg(predictions, references)
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Cloze(BaseEvaluator): class MedBenchEvaluator_Cloze(BaseEvaluator):
def score(self, predictions, references): def score(self, predictions, references):
# predictions: [[]] erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "小儿风湿免疫科", "放射科", "小儿内分泌代谢科", "急诊科", "心血管内科", "小儿神经内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "小儿心脏中心", "产科", "血液内科", "小儿普外科", "小儿泌尿外科", "小儿感染科", "临床营养科", "小儿骨科", "发育行为儿童保健科", "小儿呼吸内科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "小儿神经外科", "康复医学科", "骨科", "风湿免疫科", "小儿内科", "眼科", "心胸外科", "小儿肾脏内科", "乳腺外科", "小儿血液肿瘤科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "小儿消化内科", "呼吸内科", "核医学科", "肾脏内科"]
# references: [[]] no_erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "放射科", "急诊科", "心血管内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "产科", "血液内科", "临床营养科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "康复医学科", "骨科", "风湿免疫科", "眼科", "心胸外科", "乳腺外科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "呼吸内科", "核医学科", "肾脏内科"]
# predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
cross_erke_list = [item for item in erke_list if '小儿' in item and item.replace('小儿', '') in no_erke_list]
cross_list = [item[2:] for item in cross_erke_list]
details = [] details = []
cnt = 0 cnt = 0
for pred, ref in zip(predictions, references): for pred, ref in zip(predictions, references):
detail = {'pred':pred, 'answer':ref, 'correct':False} detail = {'pred':pred, 'answer':ref, 'correct':False}
current_pred = []
if sum([item in pred for item in ref]) == len(ref): for x in cross_list:
if '小儿' + x in predictions:
current_pred.append('小儿' + x)
elif x in predictions:
current_pred.append(x)
for x in (set(erke_list + no_erke_list) - set(cross_erke_list) - set(cross_list)):
if x in predictions:
current_pred.append(x)
# if set([x for x in erke_list + no_erke_list if x in pred]) == set(ref):
if set(current_pred) == set(ref):
cnt += 1 cnt += 1
detail['correct'] = True detail['correct'] = True
details.append(detail) details.append(detail)
......
...@@ -148,8 +148,8 @@ def parse_math_answer(setting_name, raw_string): ...@@ -148,8 +148,8 @@ def parse_math_answer(setting_name, raw_string):
last_match = None last_match = None
if '=' in s: if '=' in s:
last_match = s.split('=')[-1].lstrip(' ').rstrip('.') last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
if '\\n' in last_match: if '\n' in last_match:
last_match = last_match.split('\\n')[0] last_match = last_match.split('\n')[0]
else: else:
pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])' pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
matches = re.findall(pattern, s) matches = re.findall(pattern, s)
...@@ -170,6 +170,8 @@ def parse_math_answer(setting_name, raw_string): ...@@ -170,6 +170,8 @@ def parse_math_answer(setting_name, raw_string):
def parse_qa_multiple_answer(string): def parse_qa_multiple_answer(string):
# if setting_name == 'few-shot-CoT': # if setting_name == 'few-shot-CoT':
# string = extract_last_line(string) # string = extract_last_line(string)
for x in ['CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB', 'BG', 'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI', 'CTA', 'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT', 'CTH', 'CD', 'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF', 'CTF', 'CTG', 'AG', 'CTD', '分级C', '分级A', 'I131', '分级B', '分级D', '131I‐MIBG', 'NYHA', 'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期', 'CKD', 'FDA', 'A级', 'B级', 'C级', 'D级', '维生素D']:
string = string.replace(x, '')
pattern = '\(*([A-Z])\)*' pattern = '\(*([A-Z])\)*'
match = re.findall(pattern, string) match = re.findall(pattern, string)
if match: if match:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment