Commit b2d16321 authored by lintangsutawika
Browse files

update loading prompts

parent 30711873
import os
import ast import ast
from typing import Dict from typing import Dict
...@@ -65,7 +66,9 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None ...@@ -65,7 +66,9 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
) )
def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_dir=None, **kwargs): def load_prompt_list(
use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
category_name, prompt_name = use_prompt.split(":") category_name, prompt_name = use_prompt.split(":")
...@@ -84,8 +87,8 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_ ...@@ -84,8 +87,8 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_
elif ".yaml" in category_name: elif ".yaml" in category_name:
import yaml import yaml
if file_dir is not None: if yaml_path is not None:
category_name = os.path.realpath(os.path.join(file_dir, category_name)) category_name = os.path.realpath(os.path.join(yaml_path, category_name))
with open(category_name, "rb") as file: with open(category_name, "rb") as file:
prompt_yaml_file = yaml.full_load(file) prompt_yaml_file = yaml.full_load(file)
...@@ -94,7 +97,7 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_ ...@@ -94,7 +97,7 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_
prompt_name, prompt_yaml_file["prompts"].keys() prompt_name, prompt_yaml_file["prompts"].keys()
) )
category_name, *prompt_name = use_prompt.split(":") # category_name, *prompt_name = use_prompt.split(":")
# TODO allow to multiple prompt naming # TODO allow to multiple prompt naming
# if len(prompt_name) > 1: # if len(prompt_name) > 1:
# prompt_list = [] # prompt_list = []
......
...@@ -45,7 +45,7 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -45,7 +45,7 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
task_list = [task for task in all_task_list if type(task) == str] task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list: for task_config in config_list:
# if "task" in task_config: # assert "task" in task_config:
# task = task_config["task"] # task = task_config["task"]
# if task in GROUP_REGISTRY: # if task in GROUP_REGISTRY:
# task_list = GROUP_REGISTRY[task] # task_list = GROUP_REGISTRY[task]
...@@ -62,7 +62,8 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -62,7 +62,8 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
{ {
**task_config, **task_config,
**{"group": group}, **{"group": group},
} },
yaml_path=os.path.dirname(yaml_path),
) )
for config in var_configs: for config in var_configs:
register_configurable_task(config) register_configurable_task(config)
...@@ -79,13 +80,16 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -79,13 +80,16 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
return 0 return 0
def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]: def check_prompt_config(
config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
all_configs = [] all_configs = []
if "use_prompt" in config: if "use_prompt" in config:
prompt_list = prompts.load_prompt_list( prompt_list = prompts.load_prompt_list(
use_prompt=config["use_prompt"], use_prompt=config["use_prompt"],
dataset_name=config["dataset_path"], dataset_name=config["dataset_path"],
subset_name=config["dataset_name"] if "dataset_name" in config else None, subset_name=config["dataset_name"] if "dataset_name" in config else None,
yaml_path=yaml_path,
) )
for idx, prompt_variation in enumerate(prompt_list): for idx, prompt_variation in enumerate(prompt_list):
all_configs.append( all_configs.append(
...@@ -98,7 +102,9 @@ def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]: ...@@ -98,7 +102,9 @@ def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
config["task"] config["task"]
if "task" in config if "task" in config
else get_task_name_from_config(config), else get_task_name_from_config(config),
prompt_variation, prompt_variation.split("/")[-1]
if ".yaml" in prompt_variation
else prompt_variation,
] ]
) )
}, },
...@@ -117,7 +123,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str: ...@@ -117,7 +123,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config) return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str, register_task=True) -> None: def include_task_folder(task_dir: str, register_task: bool = True) -> None:
""" """
Calling this function Calling this function
""" """
...@@ -129,29 +135,33 @@ def include_task_folder(task_dir: str, register_task=True) -> None: ...@@ -129,29 +135,33 @@ def include_task_folder(task_dir: str, register_task=True) -> None:
try: try:
config = utils.load_yaml_config(yaml_path) config = utils.load_yaml_config(yaml_path)
# if ("prompts" in config) and (len(config.keys()) == 1): if "task" not in config:
continue
# continue
if register_task: all_configs = check_prompt_config(
all_configs = check_prompt_config(config) config, yaml_path=os.path.dirname(yaml_path)
)
for config in all_configs: for config in all_configs:
if register_task:
if type(config["task"]) == str:
register_configurable_task(config) register_configurable_task(config)
else: else:
# If a `task` in config is a list,
# that means it's a benchmark
if type(config["task"]) == list: if type(config["task"]) == list:
register_configurable_group(config, yaml_path) register_configurable_group(config, yaml_path)
except Exception as error: except Exception as error:
import traceback import traceback
print(traceback.format_exc())
print("###")
print(yaml_path)
eval_logger.warning( eval_logger.warning(
"Failed to load config in\n" "Failed to load config in\n"
f" {yaml_path}\n" f" {yaml_path}\n"
" Config will not be added to registry\n" " Config will not be added to registry\n"
f" Error: {error}" f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}"
) )
return 0
def include_path(task_dir): def include_path(task_dir):
...@@ -160,6 +170,7 @@ def include_path(task_dir): ...@@ -160,6 +170,7 @@ def include_path(task_dir):
include_task_folder(task_dir, register_task=False) include_task_folder(task_dir, register_task=False)
return 0 return 0
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_path(task_dir) include_path(task_dir)
......
...@@ -54,7 +54,7 @@ if __name__ == "__main__": ...@@ -54,7 +54,7 @@ if __name__ == "__main__":
shot = "Q:" + shot shot = "Q:" + shot
try: try:
answer = answer_regex.search(shot)[0] answer = answer_regex.search(shot)[0]
except: except Exception:
print("task", task) print("task", task)
print(shot) print(shot)
example = shot.split("Let's think step by step.")[0] example = shot.split("Let's think step by step.")[0]
......
group: flan_anli group: flan_anli
task: task:
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
task: anli_r1 task: anli_r1
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/anli.yaml:* use_prompt: prompt_templates/anli.yaml:*
validation_split: dev_r1 validation_split: dev_r1
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
task: anli_r2 task: anli_r2
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/anli.yaml:* use_prompt: prompt_templates/anli.yaml:*
validation_split: dev_r2 validation_split: dev_r2
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
task: anli_r3 task: anli_r3
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/anli.yaml:* use_prompt: prompt_templates/anli.yaml:*
validation_split: dev_r3 validation_split: dev_r3
group: flan_arc group: flan_arc
task: task:
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
task: arc_easy task: arc_easy
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Easy dataset_name: ARC-Easy
use_prompt: flan/prompt_templates/arc.yaml:* use_prompt: prompt_templates/arc.yaml:*
validation_split: validation validation_split: validation
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
task: arc_challenge task: arc_challenge
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Challenge dataset_name: ARC-Challenge
use_prompt: flan/prompt_templates/arc.yaml:* use_prompt: prompt_templates/arc.yaml:*
validation_split: validation validation_split: validation
group: flan_boolq group: flan_boolq
task: task:
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
dataset_path: super_glue dataset_path: super_glue
dataset_name: boolq dataset_name: boolq
use_prompt: flan/prompt_templates/boolq.yaml:* use_prompt: prompt_templates/boolq.yaml:*
validation_split: validation validation_split: validation
group: flan_cot group: flan_cot
task: task:
- include: flan/yaml_templates/cot_template_yaml - include: yaml_templates/cot_template_yaml
dataset_path: gsmk dataset_path: gsmk
dataset_name: boolq dataset_name: boolq
use_prompt: promptsource:* use_prompt: promptsource:*
validation_split: validation validation_split: validation
- include: flan/yaml_templates/cot_template_yaml - include: yaml_templates/cot_template_yaml
dataset_path: EleutherAI/asdiv dataset_path: EleutherAI/asdiv
use_prompt: promptsource:* use_prompt: promptsource:*
validation_split: validation validation_split: validation
group: flan_rte group: flan_rte
task: task:
- include: flan/yaml_templates/held_in_template_yaml - include: yaml_templates/held_in_template_yaml
dataset_path: super_glue dataset_path: super_glue
dataset_name: rte dataset_name: rte
use_prompt: flan/prompt_templates/flan_rte.yaml:* use_prompt: prompt_templates/rte.yaml:*
validation_split: validation validation_split: validation
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment