"vscode:/vscode.git/clone" did not exist on "4a0551a85b3b6d99c1470df5c8dee1d9c2ffe248"
Commit b2d16321 authored by lintangsutawika
Browse files

update loading prompts

parent 30711873
import os
import ast
from typing import Dict
......@@ -65,7 +66,9 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
)
def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_dir=None, **kwargs):
def load_prompt_list(
use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
category_name, prompt_name = use_prompt.split(":")
......@@ -84,8 +87,8 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_
elif ".yaml" in category_name:
import yaml
if file_dir is not None:
category_name = os.path.realpath(os.path.join(file_dir, category_name))
if yaml_path is not None:
category_name = os.path.realpath(os.path.join(yaml_path, category_name))
with open(category_name, "rb") as file:
prompt_yaml_file = yaml.full_load(file)
......@@ -94,7 +97,7 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_
prompt_name, prompt_yaml_file["prompts"].keys()
)
category_name, *prompt_name = use_prompt.split(":")
# category_name, *prompt_name = use_prompt.split(":")
# TODO: allow multiple prompt names
# if len(prompt_name) > 1:
# prompt_list = []
......
......@@ -45,24 +45,25 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list:
# if "task" in task_config:
# task = task_config["task"]
# if task in GROUP_REGISTRY:
# task_list = GROUP_REGISTRY[task]
# elif task in TASK_REGISTRY:
# task_list = [TASK_REGISTRY[task]]
# for _task in task_list:
# task_config = {
# **_task["CONFIG"],
# **task_config
# }
# assert "task" in task_config:
# task = task_config["task"]
# if task in GROUP_REGISTRY:
# task_list = GROUP_REGISTRY[task]
# elif task in TASK_REGISTRY:
# task_list = [TASK_REGISTRY[task]]
# for _task in task_list:
# task_config = {
# **_task["CONFIG"],
# **task_config
# }
task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
},
yaml_path=os.path.dirname(yaml_path),
)
for config in var_configs:
register_configurable_task(config)
......@@ -79,13 +80,16 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
return 0
def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
def check_prompt_config(
config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
all_configs = []
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
use_prompt=config["use_prompt"],
dataset_name=config["dataset_path"],
subset_name=config["dataset_name"] if "dataset_name" in config else None,
yaml_path=yaml_path,
)
for idx, prompt_variation in enumerate(prompt_list):
all_configs.append(
......@@ -98,7 +102,9 @@ def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
config["task"]
if "task" in config
else get_task_name_from_config(config),
prompt_variation,
prompt_variation.split("/")[-1]
if ".yaml" in prompt_variation
else prompt_variation,
]
)
},
......@@ -117,7 +123,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str, register_task=True) -> None:
def include_task_folder(task_dir: str, register_task: bool = True) -> None:
"""
Calling this function
"""
......@@ -129,29 +135,33 @@ def include_task_folder(task_dir: str, register_task=True) -> None:
try:
config = utils.load_yaml_config(yaml_path)
# if ("prompts" in config) and (len(config.keys()) == 1):
# continue
if "task" not in config:
continue
if register_task:
all_configs = check_prompt_config(config)
for config in all_configs:
register_configurable_task(config)
else:
# If a `task` in config is a list,
# that means it's a benchmark
if type(config["task"]) == list:
register_configurable_group(config, yaml_path)
all_configs = check_prompt_config(
config, yaml_path=os.path.dirname(yaml_path)
)
for config in all_configs:
if register_task:
if type(config["task"]) == str:
register_configurable_task(config)
else:
if type(config["task"]) == list:
register_configurable_group(config, yaml_path)
except Exception as error:
import traceback
print(traceback.format_exc())
print("###")
print(yaml_path)
eval_logger.warning(
"Failed to load config in\n"
f" {yaml_path}\n"
" Config will not be added to registry\n"
f" Error: {error}"
f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}"
)
return 0
def include_path(task_dir):
......@@ -160,6 +170,7 @@ def include_path(task_dir):
include_task_folder(task_dir, register_task=False)
return 0
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_path(task_dir)
......
......@@ -54,7 +54,7 @@ if __name__ == "__main__":
shot = "Q:" + shot
try:
answer = answer_regex.search(shot)[0]
except:
except Exception:
print("task", task)
print(shot)
example = shot.split("Let's think step by step.")[0]
......
group: flan_anli
task:
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
task: anli_r1
dataset_path: anli
use_prompt: flan/prompt_templates/anli.yaml:*
use_prompt: prompt_templates/anli.yaml:*
validation_split: dev_r1
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
task: anli_r2
dataset_path: anli
use_prompt: flan/prompt_templates/anli.yaml:*
use_prompt: prompt_templates/anli.yaml:*
validation_split: dev_r2
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
task: anli_r3
dataset_path: anli
use_prompt: flan/prompt_templates/anli.yaml:*
use_prompt: prompt_templates/anli.yaml:*
validation_split: dev_r3
group: flan_arc
task:
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
task: arc_easy
dataset_path: ai2_arc
dataset_name: ARC-Easy
use_prompt: flan/prompt_templates/arc.yaml:*
use_prompt: prompt_templates/arc.yaml:*
validation_split: validation
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
task: arc_challenge
dataset_path: ai2_arc
dataset_name: ARC-Challenge
use_prompt: flan/prompt_templates/arc.yaml:*
use_prompt: prompt_templates/arc.yaml:*
validation_split: validation
group: flan_boolq
task:
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
dataset_path: super_glue
dataset_name: boolq
use_prompt: flan/prompt_templates/boolq.yaml:*
use_prompt: prompt_templates/boolq.yaml:*
validation_split: validation
group: flan_cot
task:
- include: flan/yaml_templates/cot_template_yaml
- include: yaml_templates/cot_template_yaml
dataset_path: gsmk
dataset_name: boolq
use_prompt: promptsource:*
validation_split: validation
- include: flan/yaml_templates/cot_template_yaml
- include: yaml_templates/cot_template_yaml
dataset_path: EleutherAI/asdiv
use_prompt: promptsource:*
validation_split: validation
group: flan_rte
task:
- include: flan/yaml_templates/held_in_template_yaml
- include: yaml_templates/held_in_template_yaml
dataset_path: super_glue
dataset_name: rte
use_prompt: flan/prompt_templates/flan_rte.yaml:*
use_prompt: prompt_templates/rte.yaml:*
validation_split: validation
......@@ -426,7 +426,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment