Commit d1c3cb3d authored by lintangsutawika's avatar lintangsutawika

expanded benchmark to allow new source of prompt templates

parent 21e1ed17
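The change in a nutshell: `use_prompt` values are `source:template` pairs, and until now the only sources were the promptsource library and the built-in registry. This commit adds local YAML files as a new source. A minimal sketch of the dispatch this introduces (names follow the lm_eval/prompts diff below; the routing itself is the only thing asserted here):

def prompt_source(use_prompt: str) -> str:
    # Split "source:template" the same way load_prompt_list does below.
    category_name, prompt_name = use_prompt.split(":")
    if category_name == "promptsource":
        return "promptsource library"        # e.g. "promptsource:*"
    elif ".yaml" in category_name:
        return "prompts section of a local yaml"  # e.g. "flan_boolq.yaml:template-0"
    return "built-in prompt registry"

assert prompt_source("promptsource:*") == "promptsource library"
assert prompt_source("flan_boolq.yaml:*") == "prompts section of a local yaml"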
......@@ -14,7 +14,7 @@ from lm_eval.api.registry import (
def include_benchmarks(task_dir):
    for root, subdirs, file_list in os.walk(task_dir):
        if (subdirs == [] or "__pycache__" in subdirs) and (len(file_list) > 0):
            for f in file_list:
                if f.endswith(".yaml"):
                    try:
......@@ -23,6 +23,9 @@ def include_benchmarks(task_dir):
                        with open(benchmark_path, "rb") as file:
                            yaml_config = yaml.full_load(file)

                        if "prompts" in yaml_config:
                            continue  # Skip it: a prompt-template file, not a benchmark

                        assert "group" in yaml_config
                        group = yaml_config["group"]
                        all_task_list = yaml_config["task"]
......@@ -34,6 +37,16 @@ def include_benchmarks(task_dir):
                        ]
                        for task_config in config_list:
                            yaml_dir = os.path.dirname(benchmark_path)
                            task_config = utils.load_yaml_config(
                                yaml_config=task_config, yaml_dir=yaml_dir
                            )
                            if "use_prompt" in task_config:
                                if "yaml" in task_config["use_prompt"]:
                                    task_config["use_prompt"] = os.path.join(
                                        root, task_config["use_prompt"]
                                    )
                            var_configs = check_prompt_config(
                                {
                                    **task_config,
......
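The `os.path.join(root, ...)` fix-up above means a relative YAML prompt source is resolved against the benchmark's own folder. A hypothetical illustration (the `lm_eval/benchmarks/flan` path is an assumption for the example, not taken from the diff):

import os

root = "lm_eval/benchmarks/flan"   # assumed walk root, for illustration only
use_prompt = "flan_boolq.yaml:*"

# Mirrors the fix-up in include_benchmarks above
if "yaml" in use_prompt:
    use_prompt = os.path.join(root, use_prompt)

print(use_prompt)  # lm_eval/benchmarks/flan/flan_boolq.yaml:*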
group: zero-shot-cot
output_type: greedy_until
validation_split: validation
doc_to_target: "{{answer}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
generation_kwargs:
  until:
    - "\n\n"
  do_sample: false
  temperature: 0.0
filter_list:
  - name: "get-answer"
    filter:
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
      - function: "take_first"
group: flan_cot
task:
  - include: cot_template_yaml
    dataset_path: super_glue
    dataset_name: boolq
    use_prompt: promptsource:*
    validation_split: validation
  - include: cot_template_yaml
    dataset_path: super_glue
    dataset_name: rte
    use_prompt: promptsource:*
    validation_split: validation
  - include: cot_template_yaml
    task: anli_r1
    dataset_path: anli
    use_prompt: promptsource:*
    validation_split: dev_r1
  - include: cot_template_yaml
    task: anli_r2
    dataset_path: anli
    use_prompt: promptsource:*
    validation_split: dev_r2
  - include: cot_template_yaml
    task: anli_r3
    dataset_path: anli
    use_prompt: promptsource:*
    validation_split: dev_r3
  - include: cot_template_yaml
    dataset_path: ai2_arc
    dataset_name: ARC-Easy
    use_prompt: promptsource:*
    validation_split: validation
  - include: cot_template_yaml
    dataset_path: ai2_arc
    dataset_name: ARC-Challenge
    use_prompt: promptsource:*
    validation_split: validation
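`use_prompt: promptsource:*` is a wildcard: one task variant is generated per matching template name. Assuming `utils.pattern_match` is a simple wildcard filter (fnmatch stands in for it here; the template names are real promptsource templates for super_glue/boolq), the expansion looks like:

from fnmatch import fnmatch

all_template_names = ["GPT-3 Style", "after_reading", "based on the following passage"]

# Keep every template whose name matches the pattern
prompt_list = [t for t in all_template_names if fnmatch(t, "*")]

# load_prompt_list returns "category:template" ids, one per variant
print([f"promptsource:{t}" for t in prompt_list])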
# Flan Prompt Templates
prompts:
  "template-0":
    doc_to_text: "{{text}}\n\nCan we conclude that {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-1":
    doc_to_text: "{{text}}\n\nIs it true that {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-2":
    doc_to_text: "{{text}}\n\n{{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-3":
    doc_to_text: "Text: {{text}}\n\nQuestion: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-4":
    doc_to_text: "{{text}}\n\nWhat's the best answer to this question: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-5":
    doc_to_text: "{{text}}\nBased on the above text what's the best answer to this question: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-6":
    doc_to_text: "{{text}}\nAnswer this question making sure that the answer is supposed by the text: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-7":
    doc_to_text: "{{text}}\n\nIs the following statement correct based on the text\n\n{{question}}\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-8":
    doc_to_text: "{{title}}\n\n{{text}}\n\nIs this statement correct \"{{question}}\"?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-9":
    doc_to_text: "Is it true that {{question}} based on the following text?\n\n{{text}}\n\n{{options_}}"
    doc_to_target: "{{answer}}"
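Each template above is a pair of `{{...}}` strings rendered against a document (the wording, including its quirks, is kept verbatim from the FLAN templates being replicated). Assuming Jinja2-style substitution, which the `{{...}}` syntax implies and which `utils.apply_template` presumably performs, rendering "template-0" for one BoolQ-style document looks like:

from jinja2 import Template

doc = {
    "text": "Mount Sharp rises from the floor of Gale Crater on Mars.",
    "question": "is mount sharp on mars",
    "options_": "OPTIONS:\n- yes\n- no",
    "answer": "yes",
}

doc_to_text = "{{text}}\n\nCan we conclude that {{question}}?\n\n{{options_}}"
prompt = Template(doc_to_text).render(**doc)
target = Template("{{answer}}").render(**doc)

print(prompt)
print(target)  # yes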
group: flan_held_in
task:
  - include: held_in_template_yaml
    dataset_path: super_glue
    dataset_name: boolq
    use_prompt: flan_boolq.yaml:*
    validation_split: validation
  # - include: held_in_template_yaml
  #   dataset_path: super_glue
  #   dataset_name: rte
  #   use_prompt: local:*
  #   validation_split: validation
  # - include: held_in_template_yaml
  #   task: anli_r1
  #   dataset_path: anli
  #   use_prompt: local:*
  #   validation_split: dev_r1
  # - include: held_in_template_yaml
  #   task: anli_r2
  #   dataset_path: anli
  #   use_prompt: local:*
  #   validation_split: dev_r2
  # - include: held_in_template_yaml
  #   task: anli_r3
  #   dataset_path: anli
  #   use_prompt: local:*
  #   validation_split: dev_r3
  # - include: held_in_template_yaml
  #   dataset_path: ai2_arc
  #   dataset_name: ARC-Easy
  #   use_prompt: local:*
  #   validation_split: validation
  # - include: held_in_template_yaml
  #   dataset_path: ai2_arc
  #   dataset_name: ARC-Challenge
  #   use_prompt: local:*
  #   validation_split: validation
output_type: greedy_until
validation_split: validation
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
generation_kwargs:
  until:
    - "\n\n"
  do_sample: false
  temperature: 0.0
......@@ -44,6 +44,14 @@ def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
        raise ValueError(
            f"{prompt_name} not in prompt list {prompts.all_template_names}"
        )
    elif ".yaml" in category_name:
        import yaml

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_string = prompt_yaml_file["prompts"][prompt_name]

        return PromptString(prompt_string)
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
......@@ -56,13 +64,42 @@ def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs):

    category_name, prompt_name = use_prompt.split(":")

    if category_name == "promptsource":
        from promptsource.templates import DatasetTemplates

        if subset_name is None:
            prompts = DatasetTemplates(dataset_name=dataset_name)
        else:
            prompts = DatasetTemplates(
                dataset_name=dataset_name, subset_name=subset_name
            )
        prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)

    elif ".yaml" in category_name:
        import yaml

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_list = utils.pattern_match(
            prompt_name, prompt_yaml_file["prompts"].keys()
        )

    return [":".join([category_name, prompt]) for prompt in prompt_list]
class PromptString:
    def __init__(self, prompt_string):
        self.prompt_string = prompt_string

    def apply(self, doc):
        doc_to_text = self.prompt_string["doc_to_text"]
        doc_to_target = self.prompt_string["doc_to_target"]

        text_string = utils.apply_template(doc_to_text, doc)
        target_string = utils.apply_template(doc_to_target, doc)

        return [text_string, target_string]
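Usage of PromptString follows directly from get_prompt above: it wraps one entry of a yaml "prompts:" section and renders both sides for a document. A sketch, with the doc fields made up for illustration:

prompt = PromptString(
    {
        "doc_to_text": "{{text}}\n\n{{question}}?\n\n{{options_}}",
        "doc_to_target": "{{answer}}",
    }
)

doc = {
    "text": "Some passage.",
    "question": "is this supported",
    "options_": "- yes\n- no",
    "answer": "yes",
}

# apply() returns [rendered_text, rendered_target]
text_string, target_string = prompt.apply(doc)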
......@@ -412,39 +412,43 @@ def import_function(loader, node):
yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.full_load(file)
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if type(include_path) == str:
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            try:
                included_yaml_config = load_yaml_config(path)
                final_yaml_config.update(included_yaml_config)
            except Exception as ex:
                # Re-raise so a broken include is not silently skipped
                raise ex

        final_yaml_config.update(yaml_config)
        return final_yaml_config

    return yaml_config
def regex_replace(string, pattern, repl, count=0):
......
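The merge order in load_yaml_config is worth spelling out: reverse() makes the last listed include load first, so earlier includes override later ones, and the including file's own keys override every include. A dict-only sketch of that precedence (the config values are invented for the example):

first = {"split": "dev_r1", "metric": "exact_match"}  # first listed include
second = {"split": "validation"}                      # second listed include
own = {"group": "flan_cot"}                           # the including yaml itself

final = {}
for cfg in reversed([first, second]):  # "Load from the last one first"
    final.update(cfg)
final.update(own)  # the including file wins over every include

assert final == {"split": "dev_r1", "metric": "exact_match", "group": "flan_cot"}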
......@@ -11,6 +11,7 @@ from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger
from lm_eval.tasks import include_task_folder
from lm_eval.benchmarks import include_benchmarks
os.environ["TOKENIZERS_PARALLELISM"] = "false"
......