"mmdet/vscode:/vscode.git/clone" did not exist on "1ec92ef4094c239ae138ace52787a2496e4cb8df"
Commit d1c3cb3d authored by lintangsutawika's avatar lintangsutawika
Browse files

expanded benchmark to allow new source of prompt templates

parent 21e1ed17
@@ -14,7 +14,7 @@ from lm_eval.api.registry import (
 
 def include_benchmarks(task_dir):
     for root, subdirs, file_list in os.walk(task_dir):
-        if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
+        if (subdirs == [] or "__pycache__" in subdirs) and (len(file_list) > 0):
             for f in file_list:
                 if f.endswith(".yaml"):
                     try:
@@ -23,6 +23,9 @@ def include_benchmarks(task_dir):
                         with open(benchmark_path, "rb") as file:
                             yaml_config = yaml.full_load(file)
 
+                        if "prompts" in yaml_config:
+                            continue  # Skip it
+
                         assert "group" in yaml_config
                         group = yaml_config["group"]
                         all_task_list = yaml_config["task"]
@@ -34,6 +37,16 @@ def include_benchmarks(task_dir):
                         ]
 
                         for task_config in config_list:
+                            yaml_dir = os.path.dirname(benchmark_path)
+                            task_config = utils.load_yaml_config(
+                                yaml_config=task_config, yaml_dir=yaml_dir
+                            )
+
+                            if "use_prompt" in task_config:
+                                if "yaml" in task_config["use_prompt"]:
+                                    task_config["use_prompt"] = os.path.join(
+                                        root, task_config["use_prompt"]
+                                    )
                             var_configs = check_prompt_config(
                                 {
                                     **task_config,
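The path handling added above means that a task whose use_prompt refers to a YAML file is resolved relative to the directory currently being walked, so the prompt file is looked up next to the benchmark definition. A minimal sketch of that resolution; the directory name is illustrative, not taken from the repository layout:

import os

# Hypothetical value: "root" stands in for the directory os.walk is visiting.
root = "lm_eval/benchmarks/flan"
task_config = {"use_prompt": "flan_boolq.yaml:*"}

if "use_prompt" in task_config:
    if "yaml" in task_config["use_prompt"]:
        task_config["use_prompt"] = os.path.join(root, task_config["use_prompt"])

print(task_config["use_prompt"])  # -> lm_eval/benchmarks/flan/flan_boolq.yaml:*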
group: zero-shot-cot
output_type: greedy_until
validation_split: validation
doc_to_target: "{{answer}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
generation_kwargs:
  until:
    - "\n\n"
  do_sample: false
  temperature: 0.0
filter_list:
  - name: "get-answer"
    filter:
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
      - function: "take_first"
group: flan_cot
task:
  - include: cot_template_yaml
    dataset_path: super_glue
    dataset_name: boolq
    use_prompt: promptsource:*
    validation_split: validation
  - include: cot_template_yaml
    dataset_path: super_glue
    dataset_name: rte
    use_prompt: promptsource:*
    validation_split: validation
  - include: cot_template_yaml
    task: anli_r1
    dataset_path: anli
    use_prompt: promptsource:*
    validation_split: dev_r1
  - include: cot_template_yaml
    task: anli_r2
    dataset_path: anli
    use_prompt: promptsource:*
    validation_split: dev_r2
  - include: cot_template_yaml
    task: anli_r3
    dataset_path: anli
    use_prompt: promptsource:*
    validation_split: dev_r3
  - include: cot_template_yaml
    task: ai2_arc
    dataset_path: ARC-Easy
    use_prompt: promptsource:*
    validation_split: validation
  - include: cot_template_yaml
    task: ai2_arc
    dataset_path: ARC-Challenge
    use_prompt: promptsource:*
    validation_split: validation
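Each `use_prompt: promptsource:*` entry above fans out into one task variant per matching promptsource template. A small sketch of that expansion, assuming promptsource is installed and that the matching is fnmatch-style (as lm_eval's utils.pattern_match appears to be):

import fnmatch
from promptsource.templates import DatasetTemplates

# Example for the super_glue/boolq entry above.
prompts = DatasetTemplates(dataset_name="super_glue", subset_name="boolq")
matched = [n for n in prompts.all_template_names if fnmatch.fnmatch(n, "*")]

# Each match becomes a separate prompt variant, addressed as "promptsource:<name>".
print([f"promptsource:{name}" for name in matched])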
# Flan Prompt Templates
prompts:
  "template-0":
    doc_to_text: "{{text}}\n\nCan we conclude that {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-1":
    doc_to_text: "{{text}}\n\nIs it true that {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-2":
    doc_to_text: "{{text}}\n\n{{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-3":
    doc_to_text: "Text: {{text}}\n\nQuestion: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-4":
    doc_to_text: "{{text}}\n\nWhat's the best answer to this question: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-5":
    doc_to_text: "{{text}}\nBased on the above text what's the best answer to this question: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-6":
    doc_to_text: "{{text}}\nAnswer this question making sure that the answer is supposed by the text: {{question}}?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-7":
    doc_to_text: "{{text}}\n\nIs the following statement correct based on the text\n\n{{question}}\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-8":
    doc_to_text: "{{title}}\n\n{{text}}\n\nIs this statement correct \"{{question}}\"?\n\n{{options_}}"
    doc_to_target: "{{answer}}"
  "template-9":
    doc_to_text: "Is it true that {{question}} based on the following text?\n\n{{text}}\n\n{{options_}}"
    doc_to_target: "{{answer}}"
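Each template above is a pair of Jinja-style strings: doc_to_text builds the model input and doc_to_target the reference. A minimal sketch of how template-0 might render, assuming the fields {{text}}, {{question}}, {{options_}} and {{answer}} exist on the processed document and that rendering is Jinja2-compatible (as lm_eval's apply_template appears to be); the document values are illustrative:

from jinja2 import Template

doc = {
    "text": "The aurora is usually visible only at high latitudes.",
    "question": "the aurora can be seen near the poles",
    "options_": "OPTIONS:\n- yes\n- no",
    "answer": "yes",
}

doc_to_text = "{{text}}\n\nCan we conclude that {{question}}?\n\n{{options_}}"  # template-0
doc_to_target = "{{answer}}"

print(Template(doc_to_text).render(**doc))
print(Template(doc_to_target).render(**doc))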
group: flan_held_in
task:
  - include: held_in_template_yaml
    dataset_path: super_glue
    dataset_name: boolq
    use_prompt: flan_boolq.yaml:*
    validation_split: validation
  # - include: held_in_template_yaml
  #   dataset_path: super_glue
  #   dataset_name: rte
  #   use_prompt: local:*
  #   validation_split: validation
  # - include: held_in_template_yaml
  #   task: anli_r1
  #   dataset_path: anli
  #   use_prompt: local:*
  #   validation_split: dev_r1
  # - include: held_in_template_yaml
  #   task: anli_r2
  #   dataset_path: anli
  #   use_prompt: local:*
  #   validation_split: dev_r2
  # - include: held_in_template_yaml
  #   task: anli_r3
  #   dataset_path: anli
  #   use_prompt: local:*
  #   validation_split: dev_r3
  # - include: held_in_template_yaml
  #   task: ai2_arc
  #   dataset_path: ARC-Easy
  #   use_prompt: local:*
  #   validation_split: validation
  # - include: held_in_template_yaml
  #   task: ai2_arc
  #   dataset_path: ARC-Challenge
  #   use_prompt: local:*
  #   validation_split: validation
output_type: greedy_until
validation_split: validation
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
generation_kwargs:
  until:
    - "\n\n"
  do_sample: false
  temperature: 0.0
@@ -44,6 +44,14 @@ def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
             raise ValueError(
                 f"{prompt_name} not in prompt list {prompts.all_template_names}"
             )
+    elif ".yaml" in category_name:
+        import yaml
+
+        with open(category_name, "rb") as file:
+            prompt_yaml_file = yaml.full_load(file)
+
+        prompt_string = prompt_yaml_file["prompts"][prompt_name]
+        return PromptString(prompt_string)
     else:
         try:
             return PROMPT_REGISTRY[category_name][prompt_name]
@@ -56,13 +64,42 @@ def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
 
 
 def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs):
-    from promptsource.templates import DatasetTemplates
-
-    if subset_name is None:
-        prompts = DatasetTemplates(dataset_name=dataset_name)
-    else:
-        prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)
-
-    category_name, prompt_name = use_prompt.split(":")
-    prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
+    category_name, prompt_name = use_prompt.split(":")
+
+    if category_name == "promptsource":
+        from promptsource.templates import DatasetTemplates
+
+        if subset_name is None:
+            prompts = DatasetTemplates(dataset_name=dataset_name)
+        else:
+            prompts = DatasetTemplates(
+                dataset_name=dataset_name, subset_name=subset_name
+            )
+        prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
+
+    elif ".yaml" in category_name:
+        import yaml
+
+        with open(category_name, "rb") as file:
+            prompt_yaml_file = yaml.full_load(file)
+
+        prompt_list = utils.pattern_match(
+            prompt_name, prompt_yaml_file["prompts"].keys()
+        )
+
     return [":".join([category_name, prompt]) for prompt in prompt_list]
+
+
+class PromptString:
+    def __init__(self, prompt_string):
+        self.prompt_string = prompt_string
+
+    def apply(self, doc):
+        doc_to_text = self.prompt_string["doc_to_text"]
+        doc_to_target = self.prompt_string["doc_to_target"]
+
+        text_string = utils.apply_template(doc_to_text, doc)
+        target_string = utils.apply_template(doc_to_target, doc)
+
+        return [text_string, target_string]
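Taken together, the two new branches mean a prompt id of the form "<file>.yaml:<template name>" is read straight from that file's prompts block and wrapped in PromptString. A sketch of how this could be exercised, assuming these helpers are importable as lm_eval.prompts and that flan_boolq.yaml is reachable from the working directory; the document fields are illustrative:

from lm_eval.prompts import get_prompt, load_prompt_list  # module path assumed

# Expand the wildcard against the "prompts" keys of the YAML file.
names = load_prompt_list("flan_boolq.yaml:template-*")
# -> ["flan_boolq.yaml:template-0", "flan_boolq.yaml:template-1", ...]

prompt = get_prompt("flan_boolq.yaml:template-0")  # returns a PromptString

doc = {"text": "...", "question": "...", "options_": "OPTIONS:\n- yes\n- no", "answer": "yes"}
text_string, target_string = prompt.apply(doc)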
@@ -412,39 +412,43 @@ def import_function(loader, node):
 yaml.add_constructor("!function", import_function)
 
 
-def load_yaml_config(yaml_path):
-    with open(yaml_path, "rb") as file:
-        yaml_config = yaml.full_load(file)
-        yaml_dir = os.path.dirname(yaml_path)
-
-        if "include" in yaml_config:
-            include_path = yaml_config["include"]
-            del yaml_config["include"]
-
-            if type(include_path) == str:
-                include_path = [include_path]
-
-            # Load from the last one first
-            include_path.reverse()
-            final_yaml_config = {}
-            for path in include_path:
-
-                # Assumes that path is a full path.
-                # If not found, assume the included yaml
-                # is in the same dir as the original yaml
-                if not os.path.isfile(path):
-                    path = os.path.join(yaml_dir, path)
-
-                try:
-                    included_yaml_config = load_yaml_config(path)
-                    final_yaml_config.update(included_yaml_config)
-                except Exception as ex:
-                    # If failed to load, ignore
-                    raise ex
-
-            final_yaml_config.update(yaml_config)
-            return final_yaml_config
-        return yaml_config
+def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
+    if yaml_config is None:
+        with open(yaml_path, "rb") as file:
+            yaml_config = yaml.full_load(file)
+        yaml_dir = os.path.dirname(yaml_path)
+
+    assert yaml_dir is not None
+
+    if "include" in yaml_config:
+        include_path = yaml_config["include"]
+        del yaml_config["include"]
+
+        if type(include_path) == str:
+            include_path = [include_path]
+
+        # Load from the last one first
+        include_path.reverse()
+        final_yaml_config = {}
+        for path in include_path:
+
+            # Assumes that path is a full path.
+            # If not found, assume the included yaml
+            # is in the same dir as the original yaml
+            if not os.path.isfile(path):
+                path = os.path.join(yaml_dir, path)
+
+            try:
+                included_yaml_config = load_yaml_config(path)
+                final_yaml_config.update(included_yaml_config)
+            except Exception as ex:
+                # If failed to load, ignore
+                raise ex
+
+        final_yaml_config.update(yaml_config)
+        return final_yaml_config
+
+    return yaml_config
 
 
 def regex_replace(string, pattern, repl, count=0):
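The reworked load_yaml_config can now be handed an already-parsed config (yaml_config=..., yaml_dir=...) instead of a path, and included files are still merged with the including file's own keys taking precedence. A small self-contained sketch of that include merge, assuming lm_eval is importable; the file names and keys below are illustrative:

import os
import tempfile

from lm_eval import utils

with tempfile.TemporaryDirectory() as tmp:
    # Base template provides defaults; the child includes it and overrides one key.
    with open(os.path.join(tmp, "base.yaml"), "w") as f:
        f.write("output_type: greedy_until\nvalidation_split: validation\n")
    with open(os.path.join(tmp, "child.yaml"), "w") as f:
        f.write("include: base.yaml\nvalidation_split: dev_r1\n")

    config = utils.load_yaml_config(yaml_path=os.path.join(tmp, "child.yaml"))
    print(config)  # -> {'output_type': 'greedy_until', 'validation_split': 'dev_r1'}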
@@ -11,6 +11,7 @@ from lm_eval import evaluator, utils
 from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger
 from lm_eval.tasks import include_task_folder
+from lm_eval.benchmarks import include_benchmarks
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"