"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6570087f140222e01d418bc90f5080c57651451c"
Commit 2a67a469 authored by baberabb's avatar baberabb
Browse files

test

parent acf73a04
include: arc_easy.yaml include: arc_easy.yaml
task: arc_challenge task: arc_challenge
# test
dataset_name: ARC-Challenge dataset_name: ARC-Challenge
...@@ -3,7 +3,6 @@ group: ...@@ -3,7 +3,6 @@ group:
task: arc_easy task: arc_easy
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Easy dataset_name: ARC-Easy
# test
output_type: multiple_choice output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
......
...@@ -3,15 +3,16 @@ Usage: ...@@ -3,15 +3,16 @@ Usage:
Writes csv and Markdown table to csv_file, md_file (below). Writes csv and Markdown table to csv_file, md_file (below).
""" """
import logging import logging
import os
from pathlib import Path from pathlib import Path
from typing import List, Union
import datasets import datasets
import pandas as pd import pandas as pd
from lm_eval import tasks from lm_eval import tasks
from lm_eval.tasks import TASK_REGISTRY from lm_eval.tasks import TASK_REGISTRY
from lm_eval.utils import load_yaml_config
from ..tests.utils import new_tasks
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -20,6 +21,39 @@ datasets.disable_caching() ...@@ -20,6 +21,39 @@ datasets.disable_caching()
tasks.initialize_tasks() tasks.initialize_tasks()
def load_changed_files(file_path: str) -> List[str]:
with open(file_path, "r") as f:
content = f.read()
words_list = [x for x in content.split()]
return words_list
def parser(full_path: List[str]) -> List[str]:
_output = set()
for x in full_path:
if x.endswith(".yaml"):
_output.add(load_yaml_config(x)["task"])
elif x.endswith(".py"):
path = [str(x) for x in (list(Path(x).parent.glob("*.yaml")))]
_output |= {load_yaml_config(x)["task"] for x in path}
return list(_output)
def new_tasks() -> Union[List[str], None]:
FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt"
if os.path.exists(FILENAME):
# If tasks folder has changed then we get the list of files from FILENAME
# and parse the yaml files to get the task names.
return parser(load_changed_files(FILENAME))
elif os.getenv("API") is not None:
# Or if API has changed then we set the ENV variable API to True
# and run given tasks.
return ["arc_easy", "hellaswag", "piqa", "wikitext"]
# if both not true just do arc_easy
else:
return None
def check(tf): def check(tf):
if tf: if tf:
return "✓" return "✓"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment