Unverified Commit 37bfb4f5 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #228 from jon-tow/truthfulqa-ver-bump

Remove `truthfulqa` dependency on `t5` and bump version
parents 8a97b28f 376855d5
...@@ -42,7 +42,7 @@ jobs: ...@@ -42,7 +42,7 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest - name: Test with pytest
run: | run: |
pytest --cov=lm_eval/ tests/ pytest -vv --cov=lm_eval/ tests/
- name: Upload to codecov - name: Upload to codecov
run: | run: |
bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
...@@ -43,7 +43,7 @@ from . import pile ...@@ -43,7 +43,7 @@ from . import pile
from . import wikitext from . import wikitext
from . import lambada_multilingual from . import lambada_multilingual
from . import mutual from . import mutual
# from . import truthfulqa from . import truthfulqa
######################################## ########################################
# Translation tasks # Translation tasks
...@@ -148,8 +148,8 @@ TASK_REGISTRY = { ...@@ -148,8 +148,8 @@ TASK_REGISTRY = {
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
# "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue # dialogue
"mutual": mutual.MuTual, "mutual": mutual.MuTual,
......
...@@ -22,14 +22,14 @@ we could try this? ...@@ -22,14 +22,14 @@ we could try this?
import csv import csv
import json import json
import numpy as np import numpy as np
import sacrebleu
from rouge_score import rouge_scorer, scoring
from lm_eval.base import rf, Task from lm_eval.base import rf, Task
from pathlib import Path from pathlib import Path
from best_download import download_file from best_download import download_file
from ..metrics import mean from ..metrics import mean
from datasets import load_metric from datasets import load_metric
from t5.evaluation import metrics
bleurt = load_metric("bleurt", cache_dir="lm_cache")
# The default QA preset prompt for all models. # The default QA preset prompt for all models.
QA_PROMPT = ( QA_PROMPT = (
...@@ -49,7 +49,7 @@ QA_PROMPT = ( ...@@ -49,7 +49,7 @@ QA_PROMPT = (
class TruthfulQAMultipleChoice(Task): class TruthfulQAMultipleChoice(Task):
VERSION = 0 VERSION = 1
DATASET_PATH = Path('data/truthfulqa/mc') DATASET_PATH = Path('data/truthfulqa/mc')
def download(self): def download(self):
...@@ -150,9 +150,13 @@ class TruthfulQAMultipleChoice(Task): ...@@ -150,9 +150,13 @@ class TruthfulQAMultipleChoice(Task):
class TruthfulQAGeneration(Task): class TruthfulQAGeneration(Task):
VERSION = 0 VERSION = 1
DATASET_PATH = Path('data/truthfulqa/generation') DATASET_PATH = Path('data/truthfulqa/generation')
def __init__(self):
super().__init__()
self.bleurt = load_metric("bleurt", cache_dir="lm_cache")
def download(self): def download(self):
if self.DATASET_PATH.exists(): if self.DATASET_PATH.exists():
return return
...@@ -249,10 +253,10 @@ class TruthfulQAGeneration(Task): ...@@ -249,10 +253,10 @@ class TruthfulQAGeneration(Task):
# Process the sentence-level BLEURT, BLEU, and ROUGE for similarity measures. # Process the sentence-level BLEURT, BLEU, and ROUGE for similarity measures.
# BLEURT # BLEURT
bleurt_scores_true = bleurt.compute( bleurt_scores_true = self.bleurt.compute(
predictions=[completion] * len(true_refs), predictions=[completion] * len(true_refs),
references=true_refs)['scores'] references=true_refs)['scores']
bleurt_scores_false = bleurt.compute( bleurt_scores_false = self.bleurt.compute(
predictions=[completion] * len(false_refs), predictions=[completion] * len(false_refs),
references=false_refs)['scores'] references=false_refs)['scores']
bleurt_correct = max(bleurt_scores_true) bleurt_correct = max(bleurt_scores_true)
...@@ -262,7 +266,7 @@ class TruthfulQAGeneration(Task): ...@@ -262,7 +266,7 @@ class TruthfulQAGeneration(Task):
bleurt_acc = int(bleurt_correct > bleurt_incorrect) bleurt_acc = int(bleurt_correct > bleurt_incorrect)
# BLEU # BLEU
bleu_scores = [metrics.bleu([ref], [completion])['bleu'] for ref in all_refs] bleu_scores = [self.bleu([[ref]], [completion]) for ref in all_refs]
bleu_correct = np.nanmax(bleu_scores[:len(true_refs)]) bleu_correct = np.nanmax(bleu_scores[:len(true_refs)])
bleu_incorrect = np.nanmax(bleu_scores[len(true_refs):]) bleu_incorrect = np.nanmax(bleu_scores[len(true_refs):])
bleu_max = bleu_correct bleu_max = bleu_correct
...@@ -270,7 +274,7 @@ class TruthfulQAGeneration(Task): ...@@ -270,7 +274,7 @@ class TruthfulQAGeneration(Task):
bleu_acc = int(bleu_correct > bleu_incorrect) bleu_acc = int(bleu_correct > bleu_incorrect)
# ROUGE-N # ROUGE-N
rouge_scores = [metrics.rouge([ref], [completion]) for ref in all_refs] rouge_scores = [self.rouge([ref], [completion]) for ref in all_refs]
# ROUGE-1 # ROUGE-1
rouge1_scores = [score['rouge1'] for score in rouge_scores] rouge1_scores = [score['rouge1'] for score in rouge_scores]
rouge1_correct = np.nanmax(rouge1_scores[:len(true_refs)]) rouge1_correct = np.nanmax(rouge1_scores[:len(true_refs)])
...@@ -360,3 +364,50 @@ class TruthfulQAGeneration(Task): ...@@ -360,3 +364,50 @@ class TruthfulQAGeneration(Task):
"rougeL_acc": True, "rougeL_acc": True,
"rougeL_diff": True, "rougeL_diff": True,
} }
def bleu(self, refs, preds):
"""
Returns `t5` style BLEU scores. See the related implementation:
https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L41
:param refs:
A `list` of `list` of reference `str`s.
:param preds:
A `list` of predicted `str`s.
"""
score = sacrebleu.corpus_bleu(
preds,
refs,
smooth_method="exp",
smooth_value=0.0,
force=False,
lowercase=False,
tokenize="intl",
use_effective_order=False
).score
return score
def rouge(self, refs, preds):
"""
Returns `t5` style ROUGE scores. See the related implementation:
https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68
:param refs:
A `list` of reference `strs`.
:param preds:
A `list` of predicted `strs`.
"""
rouge_types = ["rouge1", "rouge2", "rougeLsum"]
scorer = rouge_scorer.RougeScorer(rouge_types)
# Add newlines between sentences to correctly compute `rougeLsum`.
def _prepare_summary(summary):
summary = summary.replace(" . ", ".\n")
return summary
# Accumulate confidence intervals.
aggregator = scoring.BootstrapAggregator()
for ref, pred in zip(refs, preds):
ref = _prepare_summary(ref)
pred = _prepare_summary(pred)
aggregator.add_scores(scorer.score(ref, pred))
result = aggregator.aggregate()
return {type: result[type].mid.fmeasure*100 for type in rouge_types}
...@@ -30,6 +30,8 @@ setuptools.setup( ...@@ -30,6 +30,8 @@ setuptools.setup(
"sqlitedict==1.6.0", "sqlitedict==1.6.0",
"pytablewriter==0.58.0", "pytablewriter==0.58.0",
"sacrebleu==1.5.0", "sacrebleu==1.5.0",
"rouge-score==0.0.4",
"bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
"pycountry==20.7.3", "pycountry==20.7.3",
"numexpr==2.7.2", "numexpr==2.7.2",
"lm_dataformat==0.0.20", "lm_dataformat==0.0.20",
...@@ -42,8 +44,5 @@ setuptools.setup( ...@@ -42,8 +44,5 @@ setuptools.setup(
"openai==0.6.4", "openai==0.6.4",
"jieba==0.42.1", "jieba==0.42.1",
"nagisa==0.2.7", "nagisa==0.2.7",
"t5==0.7.1",
"tensorflow-estimator==2.6.0",
"bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
] ]
) )
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import os import os
import json import json
import hashlib import hashlib
import collections
os.makedirs("tests/testdata", exist_ok=True) os.makedirs("tests/testdata", exist_ok=True)
...@@ -15,7 +16,11 @@ def assert_target(name, ob): ...@@ -15,7 +16,11 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json" fname = f"tests/testdata/{name}.json"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
assert json.load(fh) == json.loads(json.dumps(ob, sort_keys=True)) # Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8)
else: else:
with open(fname, 'w') as fh: with open(fname, 'w') as fh:
json.dump(ob, fh, sort_keys=True) json.dump(ob, fh, sort_keys=True)
...@@ -29,6 +34,17 @@ def assert_target_hashed(name, ob): ...@@ -29,6 +34,17 @@ def assert_target_hashed(name, ob):
with open(fname, 'w') as fh: with open(fname, 'w') as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest()) fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest())
# from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
# make sure eval results for a task version are stable # make sure eval results for a task version are stable
......
1a280973bbac2b7ac29dd64dddac474fb4749585f7de893483b4034814466c67
\ No newline at end of file
{"results": {"truthfulqa_gen": {"bleu_acc": 0.0, "bleu_acc_stderr": 0.0, "bleu_diff": 0.0, "bleu_diff_stderr": 0.0, "bleu_max": 0.0, "bleu_max_stderr": 0.0, "bleurt_acc": 0.835985312117503, "bleurt_acc_stderr": 0.012962704327492454, "bleurt_diff": 0.14077322143090107, "bleurt_diff_stderr": 0.005459888909582694, "bleurt_max": -1.4399358725752065, "bleurt_max_stderr": 0.0022126992369197133, "rouge1_acc": 0.0, "rouge1_acc_stderr": 0.0, "rouge1_diff": 0.0, "rouge1_diff_stderr": 0.0, "rouge1_max": 0.0, "rouge1_max_stderr": 0.0, "rouge2_acc": 0.0, "rouge2_acc_stderr": 0.0, "rouge2_diff": 0.0, "rouge2_diff_stderr": 0.0, "rouge2_max": 0.0, "rouge2_max_stderr": 0.0, "rougeL_acc": 0.0, "rougeL_acc_stderr": 0.0, "rougeL_diff": 0.0, "rougeL_diff_stderr": 0.0, "rougeL_max": 0.0, "rougeL_max_stderr": 0.0}}, "versions": {"truthfulqa_gen": 1}}
\ No newline at end of file
1e07020e9cf41d46ed65312eb39d2b8e6599673d4f0d6b67c0d0eba0efb493bb
\ No newline at end of file
{"results": {"truthfulqa_mc": {"mc1": 0.23255813953488372, "mc1_stderr": 0.01478915753108052, "mc2": 0.4462325560722362, "mc2_stderr": 0.004986523944692003}}, "versions": {"truthfulqa_mc": 1}}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment