Unverified Commit eea16d36, authored by Jess, committed by GitHub
Browse files

Merge branch 'EleutherAI:main' into main

parents 72f5f4b1 885f48d6
try:
from unitxt import evaluate
except ImportError:
raise ImportError(
"Package 'unitxt' is not installed. To install it, use `pip install 'lm_eval[unitxt]'`"
)
from lm_eval.api.registry import AGGREGATION_REGISTRY, METRIC_REGISTRY, register_metric
def unitxt_agg_metric(items):
    """Aggregate per-document unitxt items into one global score.

    Args:
        items: Sequence of ``(results, doc, metric)`` triples, the shape
            produced by ``process_results``. ``results[0]`` is used as the
            prediction, ``doc`` as the reference instance passed to unitxt,
            and ``metric`` as the registry name (``unitxt_<name>``).

    Returns:
        The ``["score"]["global"]["score"]`` entry of the first instance
        returned by ``evaluate``.
    """
    preds = [pred[0] for pred, _, _ in items]
    # Only the first item's metric name is consulted; this assumes every item
    # in a single aggregation call carries the same metric.
    metric_name = items[0][2].replace("unitxt_", "metrics.")
    # Hand unitxt shallow copies with the "metrics" field overridden instead of
    # writing into the caller's docs, so the original documents keep their
    # full metric list after aggregation.
    refs = [{**ref, "metrics": [metric_name]} for _, ref, _ in items]
    result_metrics = evaluate(preds, refs)
    return result_metrics[0]["score"]["global"]["score"]
# Expose the aggregation under the "unitxt" key, the same name that
# process_results passes as `aggregation=` when it registers a metric.
AGGREGATION_REGISTRY["unitxt"] = unitxt_agg_metric
def unitxt_metric(items):
    """Return ``items`` unchanged.

    This is the callable that ``process_results`` registers for each unitxt
    metric. It is a pure passthrough: the argument is handed back as-is, and
    the actual scoring is done in ``unitxt_agg_metric``.
    """
    return items
def process_results(doc, results):
    """Package a document's results for every unitxt metric it lists.

    Args:
        doc: Document mapping; its ``"metrics"`` entry is iterated to get the
            unitxt metric names.
        results: Model outputs for this document, stored untouched in each
            returned triple.

    Returns:
        A dict keyed by registry metric name (``"metrics."`` replaced with
        ``"unitxt_"``), each value being ``(results, doc, registry_name)``.

    Side effects:
        Any registry name not yet in ``METRIC_REGISTRY`` is registered with
        the ``"unitxt"`` aggregation and ``unitxt_metric`` as its function.
    """
    unitxt_names = doc["metrics"]
    scores = {}
    for unitxt_name in unitxt_names:
        registry_name = unitxt_name.replace("metrics.", "unitxt_")
        scores[registry_name] = (results, doc, registry_name)
        if registry_name in METRIC_REGISTRY:
            continue
        # First time this metric is seen: register it so the harness can
        # resolve it. Every metric is declared as higher-is-better here.
        register = register_metric(
            metric=registry_name,
            higher_is_better=True,
            output_type="generate_until",
            aggregation="unitxt",
        )
        register(unitxt_metric)
    return scores
#
include: unitxt_tasks.summarization.abstractive
task: xsum
dataset_name: card=cards.xsum,template=templates.summarization.abstractive.full
include: unitxt_tasks.classification.multi_class
task: yahoo_answers_topics
dataset_name: card=cards.yahoo_answers_topics,template=templates.classification.multi_class.title
......@@ -242,7 +242,7 @@ class Reorderer:
return res
def make_table(result_dict, column: str = "results"):
def make_table(result_dict, column: str = "results", sort_results: bool = True):
"""Generate table of results."""
from pytablewriter import LatexTableWriter, MarkdownTableWriter
......@@ -269,7 +269,12 @@ def make_table(result_dict, column: str = "results"):
values = []
for k, dic in result_dict[column].items():
keys = result_dict[column].keys()
if sort_results:
# sort entries alphabetically
keys = sorted(keys)
for k in keys:
dic = result_dict[column][k]
version = result_dict["versions"].get(k, "N/A")
n = str(result_dict["n-shot"][k])
......
......@@ -76,6 +76,7 @@ testing = ["pytest", "pytest-cov", "pytest-xdist"]
vllm = ["vllm==0.3.2"]
zeno = ["pandas", "zeno-client"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
unitxt = ["unitxt"]
all = [
"lm_eval[anthropic]",
"lm_eval[dev]",
......@@ -94,6 +95,7 @@ all = [
"lm_eval[vllm]",
"lm_eval[zeno]",
"lm_eval[wandb]",
"lm_eval[unitxt]"
]
[tool.ruff.lint]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment