Unverified Commit 7c9fbcf8 authored by PabloAgustin, committed by GitHub

New healthcare benchmark: careqa (#2714)



* New healthcare benchmark: careqa

* LAUNCH_MN5_ACC <python main.py --config config/mn5.yml --models Llama-3.2-1B-Instruct --tasks careqa_open --num_fewshot 0>

* Add fixes, READMES, and remove task_list.txt

* pre-commit passed, add formatting updates; add nanmean agg_metric

* Fix import error.

* Wrapped imports in try excepts

* Wrapped imports in try excepts; also metrics to catch bert_score import error

* Try except to catch ImportErrors as well

* use np.nan

* pre-commit

---------
Co-authored-by: PabloAgustin <pablo.martin@bsc.es>
Co-authored-by: Baber <baber@hey.com>
parent 2c8ffb80
include: meddialog_raw_dialogues.yaml
task: meddialog_raw_perplexity
output_type: loglikelihood_rolling
doc_to_text: ""
process_results: !function utils_perplexity.process_results_raw
metric_list:
- metric: word_perplexity
higher_is_better: false
- metric: byte_perplexity
higher_is_better: false
- metric: bits_per_byte
higher_is_better: false
metadata:
version: 1.0
import numpy as np
try:
import evaluate
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleurt = evaluate.load("bleurt", "bleurt-base-512", module_type="metric")
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError(
"Please install evaluation metrics via pip install evaluate and pip install bert-score",
)
except Exception as e:
raise RuntimeError(
f"Error loading evaluation metrics: {str(e)}. Please check your installation."
)
def doc_eval(pred, refs):
try:
bleu_results = bleu.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Bleu error: {e}")
bleu_results = {"bleu": np.nan}
try:
rouge_results = rouge.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Rouge error: {e}")
rouge_results = {"rouge1": np.nan, "rouge2": np.nan, "rougeL": np.nan}
try:
bleurt_scores = bleurt.compute(predictions=pred, references=refs)["scores"]
except Exception as e:
print(f"Bleurt error: {e}")
bleurt_scores = [np.nan]
try:
bert_scores = bertscore.compute(predictions=pred, references=refs, lang="en")[
"f1"
]
except Exception as e:
print(f"Bert error: {e}")
bert_scores = [np.nan]
if bleu_results["bleu"] == 0:
# Sometimes bleu is 0.0 and this breaks the stderr computation.
bleu_results["bleu"] += 1e-5
results = {
"bleu": bleu_results["bleu"],
"rouge1": rouge_results["rouge1"],
"rouge2": rouge_results["rouge2"],
"rougeL": rouge_results["rougeL"],
"bleurt": np.mean(bleurt_scores),
"bert_score": np.mean(bert_scores),
}
return results
def doc_to_text_raw(doc) -> str:
return doc["description"]
def doc_to_target_raw(doc) -> str:
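    # Assumption: utterances["utterance"][0] is the patient's opening message and [1] is the doctor's reply, used here as the reference.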
return doc["utterances"]["utterance"][1]
def process_results_gen_raw(doc, results):
pred, refs = [results[0]], [doc_to_target_raw(doc)]
if len(refs[0]) < 1 or len(pred[0]) < 1:
return {
"bleu": np.nan,
"rouge1": np.nan,
"rouge2": np.nan,
"rougeL": np.nan,
"bleurt": np.nan,
"bert_score": np.nan,
}
results = doc_eval(pred, refs)
return {
"bleu": results["bleu"],
"rouge1": results["rouge1"],
"rouge2": results["rouge2"],
"rougeL": results["rougeL"],
"bleurt": results["bleurt"],
"bert_score": results["bert_score"],
}
def doc_to_text_qsumm(doc) -> str:
return doc["src"]
def doc_to_target_qsumm(doc) -> str:
return doc["tgt"]
def process_results_gen_qsumm(doc, results):
pred, refs = [results[0]], [doc_to_target_qsumm(doc)]
if len(refs[0]) < 1 or len(pred[0]) < 1:
return {
"bleu": np.nan,
"rouge1": np.nan,
"rouge2": np.nan,
"rougeL": np.nan,
"bleurt": np.nan,
"bert_score": np.nan,
}
results = doc_eval(pred, refs)
return {
"bleu": results["bleu"],
"rouge1": results["rouge1"],
"rouge2": results["rouge2"],
"rougeL": results["rougeL"],
"bleurt": results["bleurt"],
"bert_score": results["bert_score"],
}
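# Illustrative usage (not part of the task; assumes the metric modules above loaded successfully):
#   toy_doc = {"src": "2-year-old with fever and a faint rash.",
#              "tgt": "Likely a viral exanthem; advise fluids and follow-up."}
#   process_results_gen_qsumm(toy_doc, ["Probably a viral rash."])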
import re
from lm_eval.tasks.meddialog.utils import doc_to_target_qsumm, doc_to_target_raw
def process_results_qsumm(doc, results):
(loglikelihood,) = results
_words = len(re.split(r"\s+", doc_to_target_qsumm(doc)))
_bytes = len(doc_to_target_qsumm(doc).encode("utf-8"))
return {
"word_perplexity": (loglikelihood, _words),
"byte_perplexity": (loglikelihood, _bytes),
"bits_per_byte": (loglikelihood, _bytes),
}
def process_results_raw(doc, results):
(loglikelihood,) = results
_words = len(re.split(r"\s+", doc_to_target_raw(doc)))
_bytes = len(doc_to_target_raw(doc).encode("utf-8"))
return {
"word_perplexity": (loglikelihood, _words),
"byte_perplexity": (loglikelihood, _bytes),
"bits_per_byte": (loglikelihood, _bytes),
}
# MEDIQA_QA 2019
### Paper
Title: `Overview of the MEDIQA 2019 Shared Task on Textual Inference, Question Entailment and Question Answering`
Abstract: [https://aclanthology.org/W19-5039/](https://aclanthology.org/W19-5039/)
Open-ended medical question answering from the MEDIQA 2019 open challenge.
Homepage: [https://sites.google.com/view/mediqa2019](https://sites.google.com/view/mediqa2019)
#### Tasks
* `mediqa_qa2019`: Open-ended QA in English.
* `mediqa_qa2019_perplexity`: Open-ended QA in English, evaluated with perplexity.
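
Both tasks run through the harness as usual; the sketch below is a hedged example of invoking the generative task via the Python API (the model checkpoint is a placeholder, and argument names may vary slightly across harness versions).

```python
# Hedged sketch: running mediqa_qa2019 zero-shot via the lm-evaluation-harness
# Python API. The checkpoint name is a placeholder.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.2-1B-Instruct",
    tasks=["mediqa_qa2019"],
    num_fewshot=0,
)
print(results["results"]["mediqa_qa2019"])
```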
### Citation
```bibtex
@inproceedings{ben-abacha-etal-2019-overview,
title = "Overview of the {MEDIQA} 2019 Shared Task on Textual Inference, Question Entailment and Question Answering",
author = "Ben Abacha, Asma and
Shivade, Chaitanya and
Demner-Fushman, Dina",
editor = "Demner-Fushman, Dina and
Cohen, Kevin Bretonnel and
Ananiadou, Sophia and
Tsujii, Junichi",
booktitle = "Proceedings of the 18th BioNLP Workshop and Shared Task",
month = aug,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/W19-5039/",
doi = "10.18653/v1/W19-5039",
pages = "370--379",
abstract = "This paper presents the MEDIQA 2019 shared task organized at the ACL-BioNLP workshop. The shared task is motivated by a need to develop relevant methods, techniques and gold standards for inference and entailment in the medical domain, and their application to improve domain specific information retrieval and question answering systems. MEDIQA 2019 includes three tasks: Natural Language Inference (NLI), Recognizing Question Entailment (RQE), and Question Answering (QA) in the medical domain. 72 teams participated in the challenge, achieving an accuracy of 98{\%} in the NLI task, 74.9{\%} in the RQE task, and 78.3{\%} in the QA task. In this paper, we describe the tasks, the datasets, and the participants' approaches and results. We hope that this shared task will attract further research efforts in textual inference, question entailment, and question answering in the medical domain."
}
```
task: mediqa_qa2019
dataset_path: bigbio/mediqa_qa
description: >
Instructions: The following text is a question asked by a patient. Answer how a doctor would, while trying to be as informative and helpful as possible.
output_type: generate_until
training_split: train_live_qa_med
validation_split: validation
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results_gen
generation_kwargs:
until:
- "\n\n"
metric_list:
- metric: bleu
aggregation: nanmean
higher_is_better: true
- metric: rouge1
aggregation: nanmean
higher_is_better: true
- metric: rouge2
aggregation: nanmean
higher_is_better: true
- metric: rougeL
aggregation: nanmean
higher_is_better: true
- metric: bleurt
aggregation: nanmean
higher_is_better: true
- metric: bert_score
aggregation: nanmean
higher_is_better: true
metadata:
version: 1.0
task: mediqa_qa2019_perplexity
dataset_path: bigbio/mediqa_qa
description: >
Instructions: The following text is a question asked by a patient. Answer how a doctor would, while trying to be as informative and helpful as possible.
output_type: loglikelihood_rolling
training_split: train_live_qa_med
validation_split: validation
test_split: test
doc_to_text: ""
doc_to_target: !function utils_perplexity.doc_to_target
process_results: !function utils_perplexity.process_results
should_decontaminate: true
doc_to_decontamination_query: !function utils_perplexity.doc_to_target
metric_list:
- metric: perplexity
higher_is_better: false
- metric: word_perplexity
higher_is_better: false
- metric: byte_perplexity
higher_is_better: false
- metric: bits_per_byte
higher_is_better: false
metadata:
version: 1.0
dataset_kwargs:
trust_remote_code: true
import numpy as np
try:
import evaluate
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleurt = evaluate.load("bleurt", "bleurt-base-512", module_type="metric")
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError(
"Please install evaluation metrics via pip install evaluate and pip install bert-score",
)
except Exception as e:
raise RuntimeError(
f"Error loading evaluation metrics: {str(e)}. Please check your installation."
)
def doc_eval(pred, refs):
try:
bleu_results = bleu.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Bleu error: {e}")
        bleu_results = {"bleu": np.nan}
try:
rouge_results = rouge.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Rouge error: {e}")
        rouge_results = {"rouge1": np.nan, "rouge2": np.nan, "rougeL": np.nan}
try:
bleurt_scores = bleurt.compute(predictions=pred, references=refs)["scores"]
except Exception as e:
print(f"Bleurt error: {e}")
        bleurt_scores = [np.nan]
try:
bert_scores = bertscore.compute(predictions=pred, references=refs, lang="en")[
"f1"
]
except Exception as e:
print(f"Bert error: {e}")
        bert_scores = [np.nan]
if bleu_results["bleu"] == 0:
# Sometimes bleu is 0.0 and this breaks the stderr computation.
bleu_results["bleu"] += 1e-5
results = {
"bleu": bleu_results["bleu"],
"rouge1": rouge_results["rouge1"],
"rouge2": rouge_results["rouge2"],
"rougeL": rouge_results["rougeL"],
"bleurt": np.mean(bleurt_scores),
"bert_score": np.mean(bert_scores),
}
return results
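# Note: when an individual metric fails above, its score falls back to np.nan.
# The task YAMLs pair each metric with a `nanmean` aggregation (added in this PR)
# so those placeholders are ignored at aggregation time; conceptually the
# aggregator behaves like numpy.nanmean, e.g. np.nanmean([0.31, np.nan, 0.28]) -> 0.295.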
def doc_to_text(doc) -> str:
return doc["QUESTION"]["QuestionText"]
def doc_to_target(doc) -> str:
return doc["QUESTION"]["AnswerList"][0]["Answer"]["AnswerText"]
def process_results_gen(doc, results):
pred, refs = [results[0]], [doc_to_target(doc)]
if len(refs[0]) < 1 or len(pred[0]) < 1:
return {
"bleu": np.NAN,
"rouge1": np.NAN,
"rouge2": np.NAN,
"rougeL": np.NAN,
"bleurt": np.NAN,
"bert_score": np.NAN,
}
results = doc_eval(pred, refs)
return {
"bleu": results["bleu"],
"rouge1": results["rouge1"],
"rouge2": results["rouge2"],
"rougeL": results["rougeL"],
"bleurt": results["bleurt"],
"bert_score": results["bert_score"],
}
import math
import re
def doc_to_target(doc) -> str:
return doc["QUESTION"]["AnswerList"][0]["Answer"]["AnswerText"]
def process_results(doc, results):
(loglikelihood,) = results
_words = len(re.split(r"\s+", doc_to_target(doc)))
_bytes = len(doc_to_target(doc).encode("utf-8"))
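    # Debug output: per-document word-level perplexity.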
print(f"perplexity: {math.exp(-loglikelihood / _words)}")
return {
"word_perplexity": (loglikelihood, _words),
"byte_perplexity": (loglikelihood, _bytes),
"bits_per_byte": (loglikelihood, _bytes),
"perplexity": (loglikelihood),
}
# MedText
### Paper
Title: `Towards Automatic Generation of Shareable Synthetic Clinical Notes Using Neural Language Models`
Abstract: [https://arxiv.org/abs/1905.07002](https://arxiv.org/abs/1905.07002)
MedText is a medical diagnosis dataset containing over 1,000 textbook-quality
patient presentations with their diagnoses and treatments. The 100 most common
diseases and the 30 most common injuries that bring people to the hospital are
fully covered, among others, with multiple data points for each, ranging from
mild to complicated to severe.
#### Tasks
* `medtext`: Open-ended QA in English.
* `medtext_perplexity`: Open-ended QA in English, evaluated with perplexity.
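
The perplexity variant scores the reference completion with `loglikelihood_rolling` and reports word perplexity, byte perplexity, and bits per byte. The sketch below illustrates, under the assumption that the harness applies its standard weighted aggregation, how those corpus-level numbers derive from per-document (log-likelihood, word count, byte count) triples; the numbers are toy values.

```python
import math

# Toy (loglikelihood, n_words, n_bytes) triples for two documents.
docs = [(-120.3, 40, 212), (-95.1, 33, 180)]
ll = sum(d[0] for d in docs)
words = sum(d[1] for d in docs)
nbytes = sum(d[2] for d in docs)

word_perplexity = math.exp(-ll / words)
byte_perplexity = math.exp(-ll / nbytes)
bits_per_byte = -ll / (nbytes * math.log(2))
print(word_perplexity, byte_perplexity, bits_per_byte)
```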
### Citation
```bibtex
@misc{melamud2019automaticgenerationshareablesynthetic,
title={Towards Automatic Generation of Shareable Synthetic Clinical Notes Using Neural Language Models},
author={Oren Melamud and Chaitanya Shivade},
year={2019},
eprint={1905.07002},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/1905.07002},
}
```
task: medtext
dataset_path: BI55/MedText
description: >
  Instructions: The following text is from a collection of medical records; what follows is the patient's record. Answer as a doctor would: what is the likely diagnosis, and what is the treatment?
output_type: generate_until
training_split: train
validation_split: train
test_split: train
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
generation_kwargs:
until:
- "\n\n"
metric_list:
- metric: bleu
aggregation: nanmean
higher_is_better: true
- metric: rouge1
aggregation: nanmean
higher_is_better: true
- metric: rouge2
aggregation: nanmean
higher_is_better: true
- metric: rougeL
aggregation: nanmean
higher_is_better: true
- metric: bleurt
aggregation: nanmean
higher_is_better: true
- metric: bert_score
aggregation: nanmean
higher_is_better: true
metadata:
version: 1.0
include: medtext.yaml
task: medtext_perplexity
output_type: loglikelihood_rolling
doc_to_text: ""
process_results: !function utils_perplexity.process_results
metric_list:
- metric: word_perplexity
higher_is_better: false
- metric: byte_perplexity
higher_is_better: false
- metric: bits_per_byte
higher_is_better: false
metadata:
version: 1.0
import numpy as np
try:
import evaluate
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleurt = evaluate.load("bleurt", "bleurt-base-512", module_type="metric")
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError(
"Please install evaluation metrics via pip install evaluate and pip install bert-score",
)
except Exception as e:
raise RuntimeError(
f"Error loading evaluation metrics: {str(e)}. Please check your installation."
)
def doc_eval(pred, refs):
try:
bleu_results = bleu.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Bleu error: {e}")
        bleu_results = {"bleu": np.nan}
try:
rouge_results = rouge.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Rouge error: {e}")
        rouge_results = {"rouge1": np.nan, "rouge2": np.nan, "rougeL": np.nan}
try:
bleurt_scores = bleurt.compute(predictions=pred, references=refs)["scores"]
except Exception as e:
print(f"Bleurt error: {e}")
        bleurt_scores = [np.nan]
try:
bert_scores = bertscore.compute(predictions=pred, references=refs, lang="en")[
"f1"
]
except Exception as e:
print(f"Bert error: {e}")
        bert_scores = [np.nan]
if bleu_results["bleu"] == 0:
# Sometimes bleu is 0.0 and this breaks the stderr computation.
bleu_results["bleu"] += 1e-5
results = {
"bleu": bleu_results["bleu"],
"rouge1": rouge_results["rouge1"],
"rouge2": rouge_results["rouge2"],
"rougeL": rouge_results["rougeL"],
"bleurt": np.mean(bleurt_scores),
"bert_score": np.mean(bert_scores),
}
return results
def doc_to_text(doc) -> str:
return doc["Prompt"]
def doc_to_target(doc) -> str:
return doc["Completion"]
def process_results(doc, results):
pred, refs = [results[0]], [doc_to_target(doc)]
if len(refs[0]) < 1 or len(pred[0]) < 1:
return {
"bleu": np.NAN,
"rouge1": np.NAN,
"rouge2": np.NAN,
"rougeL": np.NAN,
"bleurt": np.NAN,
"bert_score": np.NAN,
}
results = doc_eval(pred, refs)
return {
"bleu": results["bleu"],
"rouge1": results["rouge1"],
"rouge2": results["rouge2"],
"rougeL": results["rougeL"],
"bleurt": results["bleurt"],
"bert_score": results["bert_score"],
}
import re
from lm_eval.tasks.medtext.utils import doc_to_target
def process_results(doc, results):
(loglikelihood,) = results
_words = len(re.split(r"\s+", doc_to_target(doc)))
_bytes = len(doc_to_target(doc).encode("utf-8"))
return {
"word_perplexity": (loglikelihood, _words),
"byte_perplexity": (loglikelihood, _bytes),
"bits_per_byte": (loglikelihood, _bytes),
}
# MeqSum
### Paper
Title: `On the Summarization of Consumer Health Questions`
Abstract: [https://aclanthology.org/P19-1215/](https://aclanthology.org/P19-1215/)
Question understanding is one of the main challenges in question answering. In real world
applications, users often submit natural language questions that are longer than needed
and include peripheral information that increases the complexity of the question, leading
to substantially more false positives in answer retrieval. In this paper, we study neural
abstractive models for medical question summarization. We introduce the MeQSum corpus of
1,000 summarized consumer health questions.
### Citation
```bibtex
@inproceedings{ben-abacha-demner-fushman-2019-summarization,
title = "On the Summarization of Consumer Health Questions",
author = "Ben Abacha, Asma and
Demner-Fushman, Dina",
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/P19-1215",
doi = "10.18653/v1/P19-1215",
pages = "2228--2234"}
```
task: meqsum
dataset_path: bigbio/meqsum
dataset_name: meqsum_source
description: >
  Instructions: The following text contains a medical question. Extract and summarize the question.
output_type: generate_until
training_split: train
validation_split: train
test_split: train
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results_gen
generation_kwargs:
until:
- "\n\n"
metric_list:
- metric: bleu
aggregation: nanmean
higher_is_better: true
- metric: rouge1
aggregation: nanmean
higher_is_better: true
- metric: rouge2
aggregation: nanmean
higher_is_better: true
- metric: rougeL
aggregation: nanmean
higher_is_better: true
- metric: bert_score
aggregation: nanmean
higher_is_better: true
- metric: bleurt
aggregation: nanmean
higher_is_better: true
metadata:
version: 1.0
import numpy as np
try:
import evaluate
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleurt = evaluate.load("bleurt", "bleurt-base-512", module_type="metric")
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError(
"Please install evaluation metrics via pip install evaluate and pip install bert-score",
)
except Exception as e:
raise RuntimeError(
f"Error loading evaluation metrics: {str(e)}. Please check your installation."
)
def doc_to_text(doc) -> str:
text = doc["CHQ"]
idx = text.find("MESSAGE")
if idx != -1:
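        # Keep only the text after the marker; the +9 offset presumably skips "MESSAGE: " (the marker plus separator).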
return text[idx + 9 :]
else:
return text
def doc_to_target(doc) -> str:
return doc["Summary"]
def process_results_gen(doc, results):
pred, refs = [results[0]], [doc_to_target(doc)]
if len(refs[0]) < 1 or len(pred[0]) < 1:
return {
"bleu": np.NAN,
"rouge1": np.NAN,
"rouge2": np.NAN,
"rougeL": np.NAN,
"bleurt": np.NAN,
"bert_score": np.NAN,
}
try:
bleu_results = bleu.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Bleu error: {e}")
        bleu_results = {"bleu": np.nan}
try:
rouge_results = rouge.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Rouge error: {e}")
        rouge_results = {"rouge1": np.nan, "rouge2": np.nan, "rougeL": np.nan}
try:
bleurt_scores = bleurt.compute(predictions=pred, references=refs)["scores"]
except Exception as e:
print(f"Bleurt error: {e}")
        bleurt_scores = [np.nan]
try:
bert_scores = bertscore.compute(predictions=pred, references=refs, lang="en")[
"f1"
]
except Exception as e:
print(f"Bert error: {e}")
        bert_scores = [np.nan]
if bleu_results["bleu"] == 0:
# Sometimes bleu is 0.0 and this breaks the stderr computation.
bleu_results["bleu"] += 1e-5
return {
"bleu": bleu_results["bleu"],
"rouge1": rouge_results["rouge1"],
"rouge2": rouge_results["rouge2"],
"rougeL": rouge_results["rougeL"],
"bleurt": np.mean(bleurt_scores),
"bert_score": np.mean(bert_scores),
}
# MIMIC-III Report Summarization
### Paper
Title: `MIMIC-III, a freely accessible critical care database`
Abstract: [https://www.nature.com/articles/sdata201635](https://www.nature.com/articles/sdata201635)
MIMIC-III contains de-identified health data from around 40,000 patients admitted to
intensive care units at a large tertiary care hospital. This task focuses on radiology
report summarization.
#### Tasks
* `mimic_repsum`: Generate extractive note summaries, evaluated with [Radgraph-F1](https://www.cell.com/patterns/fulltext/S2666-3899(23)00157-5), bleu, rouge, bert_score, and bleurt.
* `mimic_repsum_perplexity`: Generate extractive note summaries, evaluated with perplexity.
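
The Radgraph-F1 score is computed with the `radgraph` package. The sketch below is a hedged illustration of the call used in this task's utils.py (`pip install radgraph` is assumed, and the example reports are placeholders).

```python
from radgraph import F1RadGraph

# Mirrors the usage in utils.py: "partial" reward level, hyps/refs keyword arguments.
f1radgraph = F1RadGraph(reward_level="partial")
score, reward_list, hyp_annotations, ref_annotations = f1radgraph(
    hyps=["No acute cardiopulmonary process."],
    refs=["No acute cardiopulmonary abnormality."],
)
print(score)
```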
### Citation
```bibtex
@article{johnson2016mimic,
title={MIMIC-III, a freely accessible critical care database},
author={Johnson, Alistair EW and Pollard, Tom J and Shen, Lu and Lehman, Li-wei H and Feng, Mengling and Ghassemi, Mohammad and Moody, Benjamin and Szolovits, Peter and Anthony Celi, Leo and Mark, Roger G},
journal={Scientific data},
volume={3},
number={1},
pages={1--9},
year={2016},
publisher={Nature Publishing Group}
}
```
task: mimic_repsum
dataset_path: dmacres/mimiciii-hospitalcourse-meta
description: >
Instructions: The following text is from a collection of medical records. Summarize the findings into diagnostic statements. Do not omit relevant information and avoid using abbreviations or jargon unless they appear in the original text.
output_type: generate_until
training_split: train
validation_split: validation
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
generation_kwargs:
until:
- "\n\n"
top_p: 0.95
metric_list:
- metric: bleu
aggregation: nanmean
higher_is_better: true
- metric: rouge1
aggregation: nanmean
higher_is_better: true
- metric: rouge2
aggregation: nanmean
higher_is_better: true
- metric: rougeL
aggregation: nanmean
higher_is_better: true
- metric: bleurt
aggregation: nanmean
higher_is_better: true
- metric: bert_score
aggregation: nanmean
higher_is_better: true
- metric: F1-Radgraph
aggregation: nanmean
higher_is_better: true
metadata:
version: 1.4
include: mimic_repsum.yaml
task: mimic_repsum_perplexity
output_type: loglikelihood_rolling
doc_to_text: ""
process_results: !function utils_perplexity.process_results
metric_list:
- metric: word_perplexity
higher_is_better: false
- metric: byte_perplexity
higher_is_better: false
- metric: bits_per_byte
higher_is_better: false
import re
from collections.abc import Iterable
import numpy as np
try:
import evaluate
from radgraph import F1RadGraph
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleurt = evaluate.load("bleurt", "bleurt-base-512", module_type="metric")
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError(
"Please install evaluation metrics via pip install evaluate and pip install bert-score",
)
except Exception as e:
raise RuntimeError(
f"Error loading evaluation metrics: {str(e)}. Please check your installation."
)
def doc_eval(pred, refs):
try:
bleu_results = bleu.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Bleu error: {e}")
        bleu_results = {"bleu": np.nan}
try:
rouge_results = rouge.compute(predictions=pred, references=refs)
except Exception as e:
print(f"Rouge error: {e}")
        rouge_results = {"rouge1": np.nan, "rouge2": np.nan, "rougeL": np.nan}
try:
bleurt_scores = bleurt.compute(predictions=pred, references=refs)["scores"]
except Exception as e:
print(f"Bleurt error: {e}")
        bleurt_scores = [np.nan]
try:
bert_scores = bertscore.compute(predictions=pred, references=refs, lang="en")[
"f1"
]
except Exception as e:
print(f"Bert error: {e}")
        bert_scores = [np.nan]
if bleu_results["bleu"] == 0:
# Sometimes bleu is 0.0 and this breaks the stderr computation.
bleu_results["bleu"] += 1e-5
results = {
"bleu": bleu_results["bleu"],
"rouge1": rouge_results["rouge1"],
"rouge2": rouge_results["rouge2"],
"rougeL": rouge_results["rougeL"],
"bleurt": np.mean(bleurt_scores),
"bert_score": np.mean(bert_scores),
}
return results
f1radgraph = F1RadGraph(reward_level="partial")
def doc_to_text(doc) -> str:
text = doc["extractive_notes_summ"]
a = re.search("IMPRESSION", text, re.IGNORECASE)
if a is not None:
a = a.start()
else:
a = -1
b = re.search("FINDING", text, re.IGNORECASE)
if b is not None:
b = b.start()
else:
b = -1
if a < b:
impressions = text[a:b].split(" ")[0]
findings = text[b:].split(" ")[0]
else:
impressions = text[a:].split(" ")[0]
findings = text[b:a].split(" ")[0]
if len(findings) < 5 < len(impressions):
findings = text[:a]
return "Given the findings: {}.\nSummarize the findings.".format(findings)
def doc_to_target(doc) -> str:
text = doc["extractive_notes_summ"]
a = re.search("IMPRESSION", text, re.IGNORECASE)
if a is not None:
a = a.start()
else:
a = -1
b = re.search("FINDING", text, re.IGNORECASE)
if b is not None:
b = b.start()
else:
b = -1
if a < b:
impressions = text[a:b].split(" ")[0]
else:
impressions = text[a:].split(" ")[0]
return impressions
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def process_results(doc, results):
pred, refs = [results[0]], [doc_to_target(doc)]
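    # Skip degenerate examples: references or predictions shorter than 5 characters are not scored.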
if len(refs[0]) < 5 or len(pred[0]) < 5:
return {
"bleu": np.NAN,
"rouge1": np.NAN,
"rouge2": np.NAN,
"rougeL": np.NAN,
"bleurt": np.NAN,
"bert_score": np.NAN,
"F1-Radgraph": np.NAN,
}
results = doc_eval(pred, refs)
try:
radgraph_score, _, _, _ = f1radgraph(hyps=pred, refs=refs)
except Exception:
        radgraph_score = np.nan
return {
"bleu": results["bleu"],
"rouge1": results["rouge1"],
"rouge2": results["rouge2"],
"rougeL": results["rougeL"],
"bleurt": results["bleurt"],
"bert_score": results["bert_score"],
"F1-Radgraph": radgraph_score,
}