Commit 35aace77 (unverified), authored by Xiaoming Shi, committed by GitHub
Browse files

[Fix] Update MedBench (#845)

parent 8ed022b4
......@@ -2,13 +2,13 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import MedBenchDataset, MedBenchEvaluator, MedBenchEvaluator_Cloze, MedBenchEvaluator_IE, MedBenchEvaluator_mcq, MedBenchEvaluator_CMeEE, MedBenchEvaluator_CMeIE, MedBenchEvaluator_CHIP_CDEE, MedBenchEvaluator_CHIP_CDN, MedBenchEvaluator_CHIP_CTC, MedBenchEvaluator_NLG, MedBenchEvaluator_TF, MedBenchEvaluator_DBMHG, MedBenchEvaluator_SMDoc, MedBenchEvaluator_IMCS_V2_MRG
from opencompass.datasets import MedBenchDataset, MedBenchEvaluator, MedBenchEvaluator_Cloze, MedBenchEvaluator_CMeEE, MedBenchEvaluator_CMeIE, MedBenchEvaluator_CHIP_CDEE, MedBenchEvaluator_CHIP_CDN, MedBenchEvaluator_CHIP_CTC, MedBenchEvaluator_NLG, MedBenchEvaluator_TF, MedBenchEvaluator_DBMHG, MedBenchEvaluator_SMDoc, MedBenchEvaluator_IMCS_V2_MRG
from opencompass.utils.text_postprocessors import first_capital_postprocess
# Reader config: the prompt text is read from 'problem_input', the gold answer from 'label'.
medbench_reader_cfg = dict(
input_columns=['problem_input'], output_column='label')
medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'SafetyBench'] # multiple-choice questions, scored by accuracy
medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'MedSafety'] # multiple-choice questions, scored by accuracy
medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # open-ended QA with reference answers
......@@ -20,31 +20,7 @@ medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CT
# Accumulates one dataset config dict per MedBench subset.
medbench_datasets = []
# Single-choice subsets: generate an answer and score it with MedBenchEvaluator_TF.
for name in medbench_single_choice_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # the retriever has no effect here; the input parameters decide (zero-shot / few-shot)
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_TF), pred_role="BOT")
medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
for name in medbench_multiple_choices_sets:
for name in medbench_single_choice_sets + medbench_multiple_choices_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
......
......@@ -12,7 +12,7 @@ from .post_process import parse_math_answer, parse_qa_multiple_answer
import evaluate
from nltk.translate.bleu_score import sentence_bleu
# from bert_score import score
# # from bert_score import score
import re
from transformers import BasicTokenizer
from rouge_chinese import Rouge
......@@ -39,33 +39,6 @@ class MedBenchDataset(BaseDataset):
return dataset
@LOAD_DATASET.register_module()
class MedBenchDataset_v2(BaseDataset):
    """Loader for MedBench ``.jsonl`` files (zero-shot setting only)."""

    @staticmethod
    def load(path: str, name: str, setting_name: str):
        """Read ``<path>/<name>.jsonl`` and build a ``Dataset``.

        Every JSON record is turned into a row with three keys:

        * ``question`` -- the optional ``passage`` followed by ``question``;
        * ``options``  -- the ``options`` list joined with newlines
          (empty string when absent or empty);
        * ``label``    -- ``label`` (a list is concatenated into one
          string), falling back to ``answer`` when ``label`` is empty.

        Args:
            path: Directory containing the data file.
            name: Subset name; the file read is ``name + '.jsonl'``.
            setting_name: Must be exactly ``'zero-shot'``.

        Returns:
            The result of ``Dataset.from_list`` over the rows.

        Raises:
            AssertionError: If ``setting_name`` is not ``'zero-shot'``.
            KeyError: If a record lacks ``question``, or lacks ``answer``
                when its ``label`` is empty.
        """
        # Exact comparison: ``in 'zero-shot'`` was a substring test, so
        # values such as '' or 'zero' were wrongly accepted.
        assert setting_name == 'zero-shot', 'only support zero-shot setting'

        filename = osp.join(path, name + '.jsonl')
        with open(filename, encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline at end of file).
            records = [json.loads(line) for line in f if line.strip()]

        rows = []
        for item in records:
            question = (item.get('passage') or '') + item['question']
            options = '\n'.join(item.get('options') or [])
            label = item.get('label')
            if label:
                if isinstance(label, list):
                    label = ''.join(label)
            else:
                label = item['answer']
            rows.append({
                'question': question,
                'options': options,
                'label': label,
            })
        return Dataset.from_list(rows)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator(BaseEvaluator):
......@@ -84,28 +57,6 @@ class MedBenchEvaluator(BaseEvaluator):
score = cnt / len(predictions) * 100
return {'Accuracy': score, 'details': details}
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_mcq(BaseEvaluator):
    """Exact-match evaluator for multiple-choice answers."""

    def score(self, predictions, references):
        """Compare predictions with references by equality.

        Args:
            predictions: Sequence of predicted answers.
            references: Sequence of gold answers, same length.

        Returns:
            ``{'score': <percentage of exact matches>, 'details': [...]}``
            where each detail is ``{'pred', 'answer', 'correct'}``; or
            ``{'error': ...}`` when the two sequences differ in length.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        details = [
            {'pred': pred, 'answer': ref, 'correct': bool(pred == ref)}
            for pred, ref in zip(predictions, references)
        ]
        n_correct = sum(detail['correct'] for detail in details)
        # An empty evaluation set used to raise ZeroDivisionError.
        score = n_correct / len(predictions) * 100 if details else 0.0
        return {'score': score, 'details': details}
def process_generated_results_CMeEE(pred_file):
# 实体每类占一行,每行格式为 "[类型名称]实体:实体名称1,实体名称2,实体名称3\n"
# 多个实体,用 ,符号分割
......@@ -114,9 +65,9 @@ def process_generated_results_CMeEE(pred_file):
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[。|;|\n]', pred):
for piece in re.split('\n', pred):
if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").replace(f'{choice}', '').replace(f'{choice}:', '').split(",")
mentions = re.split(r"[,,]", piece.replace(f"{choice}:", "").replace(f"{choice}", ""))
for ment in mentions:
list_entities.append({'type':choice, 'entity':ment})
structured_output.append(list_entities)
......@@ -124,59 +75,41 @@ def process_generated_results_CMeEE(pred_file):
def process_generated_results_EMR(pred_file):
    """Parse medical-record style predictions into ``{section: text}`` dicts.

    A prediction is cut into blocks on blank lines when it contains any
    (``"\\n\\n"``), otherwise on single line breaks. A block is kept when,
    after stripping, it starts with one of the six known section names
    followed by an ASCII or full-width colon; the remainder (stripped) is
    the section text. A section seen twice keeps its last value.

    Args:
        pred_file: Iterable of prediction strings.

    Returns:
        A list with one dict per prediction, mapping section name to text.
        Predictions with no recognisable section yield an empty dict.
    """
    # Compiled once instead of being re-looked-up for every block.
    section_re = re.compile(
        r"^(主诉|现病史|既往史|个人史|婚育史|家族史)[::]([\s\S]+)$"
    )

    structured_output = []
    for prediction in pred_file:
        if "\n\n" in prediction:
            blocks = prediction.split("\n\n")
        else:
            blocks = prediction.splitlines()

        entities = {}
        for block in blocks:
            match = section_re.match(block.strip())
            if match:
                entities[match[1]] = match[2].strip()
        structured_output.append(entities)
    return structured_output
def process_generated_results_CMeIE(pred_file):
    """Extract (predicate, subject, object) triples from predictions.

    Each triple is expected in the form
    ``关系:"<predicate>",头实体:"<head>",尾实体:"<tail>"`` where colons and
    commas may be ASCII or full-width and quotes may be straight or curly.
    Only predicates from the fixed relation list below are recognised.

    Args:
        pred_file: Iterable of prediction strings.

    Returns:
        A list with one entry per prediction; each entry is a list of
        ``{"predicate", "subject", "object"}`` dicts in order of appearance.
    """
    answer_choices = "相关(导致)、鉴别诊断、遗传因素、发病性别倾向、相关(症状)、手术治疗、预防、辅助检查、筛查、阶段、临床表现、风险评估因素、同义词、发病年龄、预后生存率、病史、传播途径、治疗后症状、药物治疗、辅助治疗、化疗、死亡率、放射治疗、病因、组织学检查、内窥镜检查、多发群体、并发症、实验室检查、就诊科室、病理生理、高危因素、发病率、多发地区、病理分型、影像学检查、转移部位、发病部位、相关(转化)、外侵部位、预后状况、发病机制、多发季节"
    re_choices = "|".join(
        re.escape(choice) for choice in answer_choices.split('、')
    )
    # Built once: the pattern does not depend on the prediction.
    triple_re = re.compile(
        rf'关系[::]["“]({re_choices})["”][,,]'
        r'头实体[::]["“]([^"”]+)["”][,,]尾实体[::]["“]([^"”]+)["”]'
    )

    structured_output = []
    for gen_output in pred_file:
        list_spos = []
        # Matching is done line by line, so a triple never spans lines.
        for answer_line in gen_output.split("\n"):
            for match in triple_re.finditer(answer_line):
                list_spos.append({
                    "predicate": match[1],
                    "subject": match[2],
                    "object": match[3],
                })
        structured_output.append(list_spos)
    return structured_output
......@@ -185,9 +118,6 @@ def process_generated_results_CDN(pred_file):
answer_choices = json.load(open('./opencompass/datasets/medbench/entity_list.jsonl', 'r'))
for line in pred_file:
gen_output = line
# 答案格式:
# 多个选中的标准化实体,用 , 符号分割
answer_str = gen_output.split("\n")[-1]
answers = answer_str.split(",")
......@@ -206,45 +136,31 @@ def process_generated_results_CDN(pred_file):
return structured_output
def _split_cdee_mentions(value):
    """Split a ','-separated field value; the placeholder '空' means none."""
    if value == "空":
        return []
    # Drop empty pieces so "描述词:" does not produce a spurious '' mention.
    return [part.strip() for part in value.split(",") if part.strip()]


def process_generated_results_CDEE(pred_file):
    """Parse clinical-finding events out of predictions.

    Every line containing ``主体词`` is treated as one event. The line is
    split on ``;`` into ``key:value`` fields (a trailing ``。`` is
    ignored). Keys and values are stripped of surrounding whitespace.

    Args:
        pred_file: Iterable of prediction strings.

    Returns:
        A list with one entry per prediction; each entry is a list of
        event dicts with the keys

        * ``主体词``   -- string, ``""`` when missing;
        * ``发生状态`` -- ``"不确定"`` or ``"否定"`` (``"不确定"`` when the
          field is missing), otherwise ``""``;
        * ``描述词`` / ``解剖部位`` -- lists of mentions split on ``,``;
          empty when the field is missing or equals ``空``.
    """
    structured_output = []
    for prediction in pred_file:
        events = []
        for line in prediction.splitlines():
            if "主体词" not in line:
                continue

            fields = {}
            for kv in line.strip().rstrip("。").split(";"):
                if ":" in kv:
                    key, value = kv.split(":", maxsplit=1)
                    fields[key.strip()] = value.strip()

            status = fields.get("发生状态", "不确定")
            events.append({
                "主体词": fields.get("主体词", ""),
                "发生状态": status if status in ("不确定", "否定") else "",
                "描述词": _split_cdee_mentions(fields.get("描述词", "空")),
                "解剖部位": _split_cdee_mentions(fields.get("解剖部位", "空")),
            })
        structured_output.append(events)
    return structured_output
def process_generated_results_CTC(pred_file):
......@@ -258,84 +174,99 @@ def process_generated_results_CTC(pred_file):
return structured_output
# Vital-sign field -> "type" label emitted in the result, in output order.
_DOC_PARSING_VITALS = {
    "体温": "体温(℃)",
    "脉搏": "脉搏(次/分)",
    "心率": "心率(次/分)",
    "收缩压": "收缩压(mmHg)",
    "舒张压": "舒张压(mmHg)",
    "呼吸": "呼吸(次/分)",
}

# Field name, then the first number after it. The gap may not cross a
# sentence delimiter (,。 or newline); otherwise "呼吸音清,心率80" would
# assign 80 to 呼吸 and leave 心率 unmatched.
_DOC_PARSING_VITAL_RE = re.compile(
    r"(体温|脉搏|心率|收缩压|舒张压|呼吸)[^\d,。\n]*(\d+(?:\.\d+)?)"
)

# Same delimiter set as the original splitter ('|' is a literal here).
_DOC_PARSING_SENTENCE_RE = re.compile("[,|。|\n]")

# Affirmative cues, excluding the negated forms 不是 / 是否 / 不存在 / 没有 / 未有.
_DOC_PARSING_AFFIRMATIVE_RE = re.compile(
    r"(?<!不)是(?!否)|(?<!不)存在|(?<![没未])有"
)
_DOC_PARSING_NEGATION_RE = re.compile("[未不没无]")

# (keyword searched in the sentence, result type, positive value, negative
# value). Order matters: only the first keyword found in a sentence is used.
_DOC_PARSING_FINDINGS = (
    ("上腹部深压痛", "是否上腹部深压痛", "是", "否"),
    ("腹部反跳痛", "是否腹部反跳痛", "是", "否"),
    ("上腹部肿块", "上腹部肿块", "扪及", "未扪及"),
)


def _doc_parsing_is_positive(sentence, positive_value):
    """Decide whether *sentence* states the finding is present."""
    # "扪及" is itself the positive answer for the mass field, so when the
    # word occurs the verdict hinges on whether it is negated (未扪及, ...).
    if positive_value == "扪及" and "扪及" in sentence:
        return not _DOC_PARSING_NEGATION_RE.search(sentence)
    return bool(_DOC_PARSING_AFFIRMATIVE_RE.search(sentence))


def process_generated_results_doc_parsing(pred_file):
    """Extract physical-exam fields from free-text predictions.

    For each prediction the result always holds the six vital signs (value
    ``"未扪及"`` when no number is found; the last number wins when a field
    occurs more than once), followed by whichever of the three abdominal
    findings were mentioned, in a fixed order.

    Args:
        pred_file: Iterable of prediction strings.

    Returns:
        A list with one entry per prediction; each entry is a list of
        ``{"type": ..., "entity": ...}`` dicts.
    """
    output = []
    for prediction in pred_file:
        vitals = dict.fromkeys(_DOC_PARSING_VITALS, "未扪及")
        # One pass over the prediction; the original repeated this scan
        # once per sentence with identical results.
        for match in _DOC_PARSING_VITAL_RE.finditer(prediction):
            vitals[match[1]] = match[2]

        findings = {}
        for sentence in _DOC_PARSING_SENTENCE_RE.split(prediction):
            for keyword, type_, positive, negative in _DOC_PARSING_FINDINGS:
                if keyword in sentence:
                    present = _doc_parsing_is_positive(sentence, positive)
                    findings[type_] = positive if present else negative
                    break

        result = [
            {"type": label, "entity": vitals[field]}
            for field, label in _DOC_PARSING_VITALS.items()
        ]
        result.extend(
            {"type": type_, "entity": findings[type_]}
            for _, type_, _, _ in _DOC_PARSING_FINDINGS
            if type_ in findings
        )
        output.append(result)
    return output
def process_generated_results_mrg(pred_file):
    """Parse medical-report predictions into ``{section: text}`` dicts.

    A prediction is cut into blocks on blank lines when it contains any
    (``"\\n\\n"``), otherwise on single line breaks. A block is kept when,
    after stripping, it starts with one of the known report sections
    followed by an ASCII or full-width colon; the remainder (stripped) is
    the section text. A section seen twice keeps its last value.

    Args:
        pred_file: Iterable of prediction strings.

    Returns:
        A list with one dict per prediction, mapping section name to text.
        Predictions with no recognisable section yield an empty dict.
    """
    # Compiled once instead of being re-looked-up for every block.
    section_re = re.compile(
        r"^(主诉|现病史|辅助检查|既往史|诊断|建议)[::]([\s\S]+)$"
    )

    structured_output = []
    for prediction in pred_file:
        if "\n\n" in prediction:
            blocks = prediction.split("\n\n")
        else:
            blocks = prediction.splitlines()

        entities = {}
        for block in blocks:
            match = section_re.match(block.strip())
            if match:
                entities[match[1]] = match[2].strip()
        structured_output.append(entities)
    return structured_output
def calc_info_extract_task_scores(list_structured_predict, list_structured_golden):
......@@ -550,8 +481,14 @@ class MedBenchEvaluator_DBMHG(BaseEvaluator):
class MedBenchEvaluator_IMCS_V2_MRG(BaseEvaluator):
    """Scores report-generation predictions as free text."""

    def score(self, predictions, references):
        """Flatten structured references to text and score against predictions.

        Each reference is indexed as an iterable of dicts with ``'type'``
        and ``'entity'`` keys and rendered as one ``"<type>:<entity>"``
        line per item. The rendered references and the raw predictions are
        then handed to ``calc_nlg_task_scores`` (defined elsewhere in this
        module), whose result is returned unchanged.
        """
        references_revise = [
            ''.join(
                f"{sub_item['type']}:{sub_item['entity']}\n"
                for sub_item in item
            )
            for item in references
        ]
        return calc_nlg_task_scores(references_revise, predictions)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeIE(BaseEvaluator):
......@@ -582,46 +519,39 @@ class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):
return calc_scores_ctc(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
    """Scores document-parsing predictions as free text."""

    def score(self, predictions, references):
        """Flatten structured references to text and score against predictions.

        Each reference is indexed as an iterable of dicts with ``'type'``
        and ``'entity'`` keys and rendered as one ``"<type>:<entity>"``
        line per item. The rendered references and the raw predictions are
        then handed to ``calc_nlg_task_scores`` (defined elsewhere in this
        module), whose result is returned unchanged.
        """
        references_revise = [
            ''.join(
                f"{sub_item['type']}:{sub_item['entity']}\n"
                for sub_item in item
            )
            for item in references
        ]
        return calc_nlg_task_scores(references_revise, predictions)


# Backward-compatible alias: the dataset config still imports the evaluator
# under its previous name.
MedBenchEvaluator_SMDoc = MedBenchEvaluator_Doc_parsing
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_NLG(BaseEvaluator):
    """Evaluator for free-text generation tasks."""

    def score(self, predictions, references):
        """Return ``calc_scores_nlg(predictions, references)`` unchanged.

        Predictions are passed through as-is; no post-processing is applied
        before scoring. ``calc_scores_nlg`` is defined elsewhere in this
        module.
        """
        return calc_scores_nlg(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Cloze(BaseEvaluator):
def score(self, predictions, references):
erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "小儿风湿免疫科", "放射科", "小儿内分泌代谢科", "急诊科", "心血管内科", "小儿神经内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "小儿心脏中心", "产科", "血液内科", "小儿普外科", "小儿泌尿外科", "小儿感染科", "临床营养科", "小儿骨科", "发育行为儿童保健科", "小儿呼吸内科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "小儿神经外科", "康复医学科", "骨科", "风湿免疫科", "小儿内科", "眼科", "心胸外科", "小儿肾脏内科", "乳腺外科", "小儿血液肿瘤科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "小儿消化内科", "呼吸内科", "核医学科", "肾脏内科"]
no_erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "放射科", "急诊科", "心血管内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "产科", "血液内科", "临床营养科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "康复医学科", "骨科", "风湿免疫科", "眼科", "心胸外科", "乳腺外科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "呼吸内科", "核医学科", "肾脏内科"]
cross_erke_list = [item for item in erke_list if '小儿' in item and item.replace('小儿', '') in no_erke_list]
cross_list = [item[2:] for item in cross_erke_list]
# predictions: [[]]
# references: [[]]
# predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
details = []
cnt = 0
for pred, ref in zip(predictions, references):
detail = {'pred':pred, 'answer':ref, 'correct':False}
current_pred = []
for x in cross_list:
if '小儿' + x in predictions:
current_pred.append('小儿' + x)
elif x in predictions:
current_pred.append(x)
for x in (set(erke_list + no_erke_list) - set(cross_erke_list) - set(cross_list)):
if x in predictions:
current_pred.append(x)
# if set([x for x in erke_list + no_erke_list if x in pred]) == set(ref):
if set(current_pred) == set(ref):
if sum([item in pred for item in ref]) == len(ref):
cnt += 1
detail['correct'] = True
details.append(detail)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment