Commit ceaef430 authored by haileyschoelkopf
Browse files

working fast (no-load) task table

parent f66fc06f
......@@ -9,6 +9,7 @@ from typing import List, Union
import datasets
import pandas as pd
from tqdm import tqdm
from lm_eval import tasks
from lm_eval.utils import load_yaml_config
......@@ -17,7 +18,7 @@ from lm_eval.utils import load_yaml_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
datasets.disable_caching()
task_manager = tasks.TaskManager
task_manager = tasks.TaskManager()
def load_changed_files(file_path: str) -> List[str]:
......@@ -38,19 +39,23 @@ def parser(full_path: List[str]) -> List[str]:
return list(_output)
def new_tasks() -> Union[List[str], None]:
def new_tasks(df=None) -> Union[List[str], None]:
new_tasks = []
FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt"
if os.path.exists(FILENAME):
# If tasks folder has changed then we get the list of files from FILENAME
# and parse the yaml files to get the task names.
return parser(load_changed_files(FILENAME))
elif os.getenv("API") is not None:
# Or if API has changed then we set the ENV variable API to True
# and run given tasks.
return ["arc_easy", "hellaswag", "piqa", "wikitext"]
# (for when run in CI)
new_tasks.extend(parser(load_changed_files(FILENAME)))
# if we already have a (partial) task table created, only add tasks
# which aren't already in task table
if df is not None:
_tasks = task_manager.all_tasks
_tasks = [k for k in _tasks if k not in df["Task Name"].values]
new_tasks.extend(_tasks)
# if both not true just do arc_easy
else:
return None
return new_tasks
def check(tf):
......@@ -64,47 +69,84 @@ def maketable(df):
headers = [
"Task Name",
"Group",
"Train",
"Val",
"Test",
"Val/Test Docs",
"Request Type,",
# "Train",
# "Val",
# "Test",
# "Val/Test Docs",
"Request Type",
"Filters",
"Metrics",
]
values = []
if not df:
_tasks = task_manager.TASK_REGISTRY.items()
_tasks = sorted(_tasks, key=lambda x: x[0])
if df is None:
_tasks = task_manager.all_tasks
else:
task_classes = new_tasks()
_tasks = [(x, task_manager.TASK_REGISTRY.get(x)) for x in task_classes]
count = 0
for tname, Task in _tasks:
task = Task()
_tasks = new_tasks(df=df)
# _tasks = [(x, task_manager.load_task_or_group(x)) for x in task_classes]
# count = 0
for tname in tqdm(_tasks):
print(tname)
# try:
# if not tname in ["advanced_ai_risk", "arithmetic", "bbh_fewshot", "bbh_cot_fewshot", "bbh_cot_zeroshot"]:
# task = task_manager.load_task_or_group(tname)
# else:
# continue
# if isinstance(list(task.values())[0], tuple): # is group, not a solo task
# del task
# continue
# else:
# task = task[tname]
# # except Exception as e:
# # print(e)
# # continue
task_config = task_manager._get_config(tname)
if not task_config:
continue
# TODO: also catch benchmark configs like flan
if not isinstance(task_config["task"], str):
continue
if task_config.get("class", None):
continue
v = [
tname,
task.config.group,
check(task.has_training_docs()),
check(task.has_validation_docs()),
check(task.has_test_docs()),
len(
list(
task.test_docs()
if task.has_test_docs()
else task.validation_docs()
if task.has_validation_docs()
else task.training_docs()
)
task_config.get("group", None),
# check(True),
# check(True),
# check(True),
# -1,
task_config.get("output_type", "greedy_until"),
", ".join(
str(f["name"])
for f in task_config.get("filter_list", [{"name": "none"}])
),
task.config.output_type,
", ".join(task.aggregation().keys()),
", ".join(str(metric["metric"]) for metric in task_config["metric_list"]),
]
# v = [
# tname,
# task.CONFIG.group,
# check(task.has_training_docs()),
# check(task.has_validation_docs()),
# check(task.has_test_docs()),
# len(
# list(
# task.test_docs()
# if task.has_test_docs()
# else task.validation_docs()
# if task.has_validation_docs()
# else task.training_docs()
# )
# ),
# task.config.output_type,
# ", ".join(task.aggregation().keys()),
# ]
logger.info(v)
values.append(v)
count += 1
if count == 10:
break
if not df:
# count += 1
# if count >= 20:
# break
# del task
if df is None:
df = pd.DataFrame(values, columns=headers)
table = df.to_markdown(index=False)
else:
......@@ -130,7 +172,7 @@ if __name__ == "__main__":
df = pd.read_csv(csv_file)
except FileNotFoundError:
df = None
df = None
df, table = maketable(df=df)
with open(md_file, "w") as f:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment