Commit 6c4222ec authored by baberabb's avatar baberabb
Browse files

add auto task_table

parent 6a1c19ed
name: Tasks Modified
on:
push:
branches:
- 'main'
pull_request:
branches:
- 'main'
......@@ -57,13 +54,22 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
# if new tasks are added, run tests on them
if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
run: python -m pytest tests/test_tasks.py -s -vv
# Edit task table if new tasks added
# copied from gpt-neox repo
- name: Update Docs
# Regenerates the task table with scripts/make_table_tasks.py, then commits
# and pushes the generated files as the "github-actions" user.
# NOTE(review): the script in this commit writes docs/task_guide.md and
# docs/task_guide.csv, but the `git add` lines below stage docs/task_table.*;
# confirm which file names are intended.
# NOTE(review): `git commit` exits non-zero when nothing is staged, which
# would fail this step on runs where the table did not change - confirm.
if: success()
run: |
python scripts/make_table_tasks.py
git config user.name github-actions
git config user.email github-actions@github.com
git add docs/task_table.md
git add docs/task_table.csv
git commit -m "Update Task Table automatically"
git push
# if api is modified, run tests on it
- name: Test more tasks with pytest
env:
......
"""
Usage:
python make_table_tasks.py
Writes csv and Markdown table to csv_file, md_file (below).
"""
import argparse
import logging
from pathlib import Path
from pytablewriter import MarkdownTableWriter
import datasets
import pandas as pd
from lm_eval import tasks
from lm_eval.tasks import TASK_REGISTRY
from tests.utils import new_tasks
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
datasets.disable_caching()
tasks.initialize_tasks()
def check(tf):
......@@ -21,34 +26,76 @@ def check(tf):
return " "
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output", type=str, default="task_table.md")
args = parser.parse_args()
writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
def maketable(df):
    """Build the task table from scratch, or refresh an existing one.

    Args:
        df: The existing table as a ``pandas.DataFrame`` (for example, loaded
            from the CSV file), or ``None`` to build a new table covering
            every entry in ``TASK_REGISTRY``.

    Returns:
        A ``(df, table)`` tuple: the up-to-date DataFrame and the same data
        rendered as a Markdown string.
    """
    headers = [
        "Task Name",
        "Group",
        "Train",
        "Val",
        "Test",
        "Val/Test Docs",
        "Request Type",
        "Metrics",
    ]

    # `not df` raises ValueError for a DataFrame (ambiguous truth value), so
    # the "no existing table" case has to be detected with an identity check.
    if df is None:
        _tasks = sorted(TASK_REGISTRY.items(), key=lambda x: x[0])
    else:
        # new_tasks() may return None when nothing relevant changed.
        task_names = new_tasks() or []
        _tasks = [(x, TASK_REGISTRY.get(x)) for x in task_names]

    values = []
    for tname, Task in _tasks:
        if Task is None:
            # Name reported by new_tasks() that has no registry entry.
            logger.warning("Task %s not found in TASK_REGISTRY; skipping", tname)
            continue
        task = Task()

        # Count docs from the first available split: test, then validation,
        # then training.
        if task.has_test_docs():
            docs = task.test_docs()
        elif task.has_validation_docs():
            docs = task.validation_docs()
        else:
            docs = task.training_docs()

        v = [
            tname,
            task.config.group,
            check(task.has_training_docs()),
            check(task.has_validation_docs()),
            check(task.has_test_docs()),
            len(list(docs)),
            task.config.output_type,
            ", ".join(task.aggregation().keys()),
        ]
        logger.info(v)
        values.append(v)

    if df is None:
        df = pd.DataFrame(values, columns=headers)
    else:
        for new_row in values:
            tname = new_row[0]
            if tname in df["Task Name"].values:
                # If task name exists, update the row
                df.loc[df["Task Name"] == tname] = new_row
            else:
                # If task name doesn't exist, append a new row
                series = pd.Series(new_row, index=df.columns)
                df = pd.concat([df, series.to_frame().T], ignore_index=True)
        df = df.sort_values(by=["Task Name"])

    table = df.to_markdown(index=False)
    return df, table
if __name__ == "__main__":
    # Both outputs go to the "docs" directory that sits next to this
    # script's parent directory.
    docs_dir = Path(__file__).parent.parent.resolve() / "docs"
    csv_file = docs_dir / "task_guide.csv"
    md_file = docs_dir / "task_guide.md"

    # Reuse the previous table when it exists; otherwise start from nothing.
    try:
        df = pd.read_csv(csv_file)
    except FileNotFoundError:
        df = None

    df, table = maketable(df=df)

    with open(md_file, "w") as md_out:
        md_out.write(table)
    with open(csv_file, "w") as csv_out:
        df.to_csv(csv_out, index=False)
from itertools import islice
import datasets
import pytest
import lm_eval.tasks as tasks
......@@ -8,6 +9,7 @@ from lm_eval.api.task import ConfigurableTask
from .utils import new_tasks
datasets.disable_caching()
tasks.initialize_tasks()
# Default Task
TASKS = ["arc_easy"]
......
......@@ -45,4 +45,4 @@ def new_tasks() -> Union[List[str], None]:
return ["arc_easy", "hellaswag", "piqa", "wikitext"]
# if both not true just do arc_easy
else:
return
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment