Commit 748a9898 authored by lintangsutawika's avatar lintangsutawika
Browse files

can now process a benchmark that uses promptsource

parent 7411466c
...@@ -3,6 +3,7 @@ import yaml
from typing import List, Union
from lm_eval import utils
from lm_eval import prompts
from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.registry import (
...@@ -14,6 +15,59 @@ from lm_eval.api.registry import (
)
def register_configurable_task(config):
    """Build a ConfigurableTask subclass from a YAML-derived config dict and
    register it in the task/group registries.

    Args:
        config: task configuration mapping. Must contain "task" (used to name
            the generated subclass and to register it); may contain "group"
            as either a single group name (str) or a list of names.

    Returns:
        0 on success (legacy convention in this module).
    """
    SubClass = type(
        config["task"] + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    if "task" in config:
        task_name = "{}".format(config["task"])
        register_task(task_name)(SubClass)

    if "group" in config:
        # Accept a single group name or a list of group names.
        if isinstance(config["group"], str):
            group_name = [config["group"]]
        else:
            group_name = config["group"]
        for group in group_name:
            register_group(group)(SubClass)

    return 0
def check_prompt_config(config):
    """Expand a task config that requests prompt templates into one config
    per prompt variation.

    When "use_prompt" is present, the matching prompt list is loaded (e.g.
    from promptsource) and one copy of the config is emitted per prompt,
    each with a unique task-name suffix ("<name>_promptsource_NN") and
    "greedy_until" output type. Otherwise the config passes through as the
    sole list element.

    Args:
        config: task configuration mapping.

    Returns:
        list of config dicts (always at least one element).
    """
    all_configs = []
    if "use_prompt" in config:
        prompt_list = prompts.load_prompt_list(
            use_prompt=config["use_prompt"],
            dataset_name=config["dataset_path"],
            # Not every dataset has a subset (e.g. hellaswag): avoid a
            # KeyError when "dataset_name" is absent.
            subset_name=config.get("dataset_name", None),
        )
        for idx, prompt_variation in enumerate(prompt_list):
            all_configs.append(
                {
                    **config,
                    "use_prompt": prompt_variation,
                    "task": "_".join(
                        [
                            get_task_name_from_config(config),
                            "promptsource",
                            str(idx).zfill(2),
                        ]
                    ),
                    "output_type": "greedy_until",
                }
            )
    else:
        all_configs.append(config)
    return all_configs
def get_task_name_from_config(task_config):
    if "dataset_name" in task_config:
        return "{dataset_path}_{dataset_name}".format(**task_config)
...@@ -32,20 +86,10 @@ def include_task_folder(task_dir): ...@@ -32,20 +86,10 @@ def include_task_folder(task_dir):
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
all_configs = check_prompt_config(config)
for config in all_configs:
register_configurable_task(config)
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
{"CONFIG": TaskConfig(**config)},
)
if "task" in config:
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config:
for group in config["group"]:
register_group(group)(SubClass)
except Exception as error:
eval_logger.warning(
"Failed to load config in\n"
...@@ -69,7 +113,24 @@ def include_benchmarks(task_dir, benchmark_dir="benchmarks"): ...@@ -69,7 +113,24 @@ def include_benchmarks(task_dir, benchmark_dir="benchmarks"):
assert "group" in yaml_config
group = yaml_config["group"]
all_task_list = yaml_config["task"]
config_list = [
task for task in all_task_list if type(task) != str
]
task_list = [
task for task in all_task_list if type(task) == str
]
for task_config in config_list:
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if task in TASK_REGISTRY:
......
group: t0_eval
task:
  # Coreference Resolution
  - dataset_path: super_glue
    dataset_name: wsc.fixed
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Coreference Resolution
  - dataset_path: winogrande
    dataset_name: winogrande_xl
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Natural Language Inference
  - dataset_path: super_glue
    dataset_name: cb
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Natural Language Inference
  - dataset_path: super_glue
    dataset_name: rte
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Natural Language Inference
  # - dataset_path: anli
  #   use_prompt: promptsource:*
  # Sentence Completion
  - dataset_path: super_glue
    dataset_name: copa
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Natural Language Inference
  - dataset_path: hellaswag
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
  # Word Sense Disambiguation
  - dataset_path: super_glue
    dataset_name: wic
    use_prompt: promptsource:*
    training_split: train
    validation_split: validation
    metric_list:
      - metric: exact_match
        aggregation: mean
        higher_is_better: true
        ignore_case: true
        ignore_punctuation: true
...@@ -108,6 +108,10 @@ class MultiChoice:
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
if type(patterns) == str:
patterns = [patterns]
task_names = set()
for pattern in patterns:
for matching in fnmatch.filter(source_list, pattern):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment