Commit ceaef430 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

working fast (no-load) task table

parent f66fc06f
...@@ -9,6 +9,7 @@ from typing import List, Union ...@@ -9,6 +9,7 @@ from typing import List, Union
import datasets import datasets
import pandas as pd import pandas as pd
from tqdm import tqdm
from lm_eval import tasks from lm_eval import tasks
from lm_eval.utils import load_yaml_config from lm_eval.utils import load_yaml_config
...@@ -17,7 +18,7 @@ from lm_eval.utils import load_yaml_config ...@@ -17,7 +18,7 @@ from lm_eval.utils import load_yaml_config
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
datasets.disable_caching() datasets.disable_caching()
task_manager = tasks.TaskManager task_manager = tasks.TaskManager()
def load_changed_files(file_path: str) -> List[str]: def load_changed_files(file_path: str) -> List[str]:
...@@ -38,19 +39,23 @@ def parser(full_path: List[str]) -> List[str]: ...@@ -38,19 +39,23 @@ def parser(full_path: List[str]) -> List[str]:
return list(_output) return list(_output)
def new_tasks() -> Union[List[str], None]: def new_tasks(df=None) -> Union[List[str], None]:
new_tasks = []
FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt" FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt"
if os.path.exists(FILENAME): if os.path.exists(FILENAME):
# If tasks folder has changed then we get the list of files from FILENAME # If tasks folder has changed then we get the list of files from FILENAME
# and parse the yaml files to get the task names. # and parse the yaml files to get the task names.
return parser(load_changed_files(FILENAME)) # (for when run in CI)
elif os.getenv("API") is not None: new_tasks.extend(parser(load_changed_files(FILENAME)))
# Or if API has changed then we set the ENV variable API to True # if we already have a (partial) task table created, only add tasks
# and run given tasks. # which aren't already in task table
return ["arc_easy", "hellaswag", "piqa", "wikitext"] if df is not None:
_tasks = task_manager.all_tasks
_tasks = [k for k in _tasks if k not in df["Task Name"].values]
new_tasks.extend(_tasks)
# if both not true just do arc_easy # if both not true just do arc_easy
else: return new_tasks
return None
def check(tf): def check(tf):
...@@ -64,47 +69,84 @@ def maketable(df): ...@@ -64,47 +69,84 @@ def maketable(df):
headers = [ headers = [
"Task Name", "Task Name",
"Group", "Group",
"Train", # "Train",
"Val", # "Val",
"Test", # "Test",
"Val/Test Docs", # "Val/Test Docs",
"Request Type,", "Request Type",
"Filters",
"Metrics", "Metrics",
] ]
values = [] values = []
if not df: if df is None:
_tasks = task_manager.TASK_REGISTRY.items() _tasks = task_manager.all_tasks
_tasks = sorted(_tasks, key=lambda x: x[0])
else: else:
task_classes = new_tasks() _tasks = new_tasks(df=df)
_tasks = [(x, task_manager.TASK_REGISTRY.get(x)) for x in task_classes] # _tasks = [(x, task_manager.load_task_or_group(x)) for x in task_classes]
count = 0 # count = 0
for tname, Task in _tasks: for tname in tqdm(_tasks):
task = Task() print(tname)
# try:
# if not tname in ["advanced_ai_risk", "arithmetic", "bbh_fewshot", "bbh_cot_fewshot", "bbh_cot_zeroshot"]:
# task = task_manager.load_task_or_group(tname)
# else:
# continue
# if isinstance(list(task.values())[0], tuple): # is group, not a solo task
# del task
# continue
# else:
# task = task[tname]
# # except Exception as e:
# # print(e)
# # continue
task_config = task_manager._get_config(tname)
if not task_config:
continue
# TODO: also catch benchmark configs like flan
if not isinstance(task_config["task"], str):
continue
if task_config.get("class", None):
continue
v = [ v = [
tname, tname,
task.config.group, task_config.get("group", None),
check(task.has_training_docs()), # check(True),
check(task.has_validation_docs()), # check(True),
check(task.has_test_docs()), # check(True),
len( # -1,
list( task_config.get("output_type", "greedy_until"),
task.test_docs() ", ".join(
if task.has_test_docs() str(f["name"])
else task.validation_docs() for f in task_config.get("filter_list", [{"name": "none"}])
if task.has_validation_docs()
else task.training_docs()
)
), ),
task.config.output_type, ", ".join(str(metric["metric"]) for metric in task_config["metric_list"]),
", ".join(task.aggregation().keys()),
] ]
# v = [
# tname,
# task.CONFIG.group,
# check(task.has_training_docs()),
# check(task.has_validation_docs()),
# check(task.has_test_docs()),
# len(
# list(
# task.test_docs()
# if task.has_test_docs()
# else task.validation_docs()
# if task.has_validation_docs()
# else task.training_docs()
# )
# ),
# task.config.output_type,
# ", ".join(task.aggregation().keys()),
# ]
logger.info(v) logger.info(v)
values.append(v) values.append(v)
count += 1 # count += 1
if count == 10: # if count >= 20:
break # break
if not df:
# del task
if df is None:
df = pd.DataFrame(values, columns=headers) df = pd.DataFrame(values, columns=headers)
table = df.to_markdown(index=False) table = df.to_markdown(index=False)
else: else:
...@@ -130,7 +172,7 @@ if __name__ == "__main__": ...@@ -130,7 +172,7 @@ if __name__ == "__main__":
df = pd.read_csv(csv_file) df = pd.read_csv(csv_file)
except FileNotFoundError: except FileNotFoundError:
df = None df = None
df = None
df, table = maketable(df=df) df, table = maketable(df=df)
with open(md_file, "w") as f: with open(md_file, "w") as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment