Commit 3f090027 authored by lintangsutawika's avatar lintangsutawika
Browse files

moved files

parent a5e93901
...@@ -65,7 +65,7 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None ...@@ -65,7 +65,7 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
) )
def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs): def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, file_dir=None, **kwargs):
category_name, prompt_name = use_prompt.split(":") category_name, prompt_name = use_prompt.split(":")
...@@ -84,6 +84,9 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa ...@@ -84,6 +84,9 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
elif ".yaml" in category_name: elif ".yaml" in category_name:
import yaml import yaml
if file_dir is not None:
category_name = os.path.realpath(os.path.join(file_dir, category_name))
with open(category_name, "rb") as file: with open(category_name, "rb") as file:
prompt_yaml_file = yaml.full_load(file) prompt_yaml_file = yaml.full_load(file)
...@@ -98,7 +101,7 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa ...@@ -98,7 +101,7 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
# for prompt in prompt_name: # for prompt in prompt_name:
# prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names)) # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
# else: # else:
prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
return [":".join([category_name, prompt]) for prompt in prompt_list] return [":".join([category_name, prompt]) for prompt in prompt_list]
......
...@@ -38,7 +38,7 @@ def register_configurable_task(config: Dict[str, str]) -> int: ...@@ -38,7 +38,7 @@ def register_configurable_task(config: Dict[str, str]) -> int:
return 0 return 0
def register_configurable_group(config: Dict[str, str]) -> int: def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
group = config["group"] group = config["group"]
all_task_list = config["task"] all_task_list = config["task"]
config_list = [task for task in all_task_list if type(task) != str] config_list = [task for task in all_task_list if type(task) != str]
...@@ -57,6 +57,7 @@ def register_configurable_group(config: Dict[str, str]) -> int: ...@@ -57,6 +57,7 @@ def register_configurable_group(config: Dict[str, str]) -> int:
# **_task["CONFIG"], # **_task["CONFIG"],
# **task_config # **task_config
# } # }
task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config( var_configs = check_prompt_config(
{ {
**task_config, **task_config,
...@@ -128,6 +129,10 @@ def include_task_folder(task_dir: str, register_task=True) -> None: ...@@ -128,6 +129,10 @@ def include_task_folder(task_dir: str, register_task=True) -> None:
try: try:
config = utils.load_yaml_config(yaml_path) config = utils.load_yaml_config(yaml_path)
# if ("prompts" in config) and (len(config.keys()) == 1):
# continue
if register_task: if register_task:
all_configs = check_prompt_config(config) all_configs = check_prompt_config(config)
for config in all_configs: for config in all_configs:
...@@ -136,9 +141,11 @@ def include_task_folder(task_dir: str, register_task=True) -> None: ...@@ -136,9 +141,11 @@ def include_task_folder(task_dir: str, register_task=True) -> None:
# If a `task` in config is a list, # If a `task` in config is a list,
# that means it's a benchmark # that means it's a benchmark
if type(config["task"]) == list: if type(config["task"]) == list:
register_configurable_group(config) register_configurable_group(config, yaml_path)
except Exception as error: except Exception as error:
import traceback
print(traceback.format_exc())
eval_logger.warning( eval_logger.warning(
"Failed to load config in\n" "Failed to load config in\n"
f" {yaml_path}\n" f" {yaml_path}\n"
......
...@@ -3,15 +3,15 @@ task: ...@@ -3,15 +3,15 @@ task:
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: anli_r1 task: anli_r1
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/flan_anli.yaml:* use_prompt: flan/prompt_templates/anli.yaml:*
validation_split: dev_r1 validation_split: dev_r1
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: anli_r2 task: anli_r2
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/flan_anli.yaml:* use_prompt: flan/prompt_templates/anli.yaml:*
validation_split: dev_r2 validation_split: dev_r2
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: anli_r3 task: anli_r3
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/flan_anli.yaml:* use_prompt: flan/prompt_templates/anli.yaml:*
validation_split: dev_r3 validation_split: dev_r3
group: flan_arc
task:
- include: flan/yaml_templates/held_in_template_yaml
task: arc_easy
dataset_path: ai2_arc
dataset_name: ARC-Easy
use_prompt: flan/prompt_templates/arc.yaml:*
validation_split: validation
- include: flan/yaml_templates/held_in_template_yaml
task: arc_challenge
dataset_path: ai2_arc
dataset_name: ARC-Challenge
use_prompt: flan/prompt_templates/arc.yaml:*
validation_split: validation
...@@ -3,5 +3,5 @@ task: ...@@ -3,5 +3,5 @@ task:
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
dataset_path: super_glue dataset_path: super_glue
dataset_name: boolq dataset_name: boolq
use_prompt: flan/prompt_templates/flan_boolq.yaml:* use_prompt: flan/prompt_templates/boolq.yaml:*
validation_split: validation validation_split: validation
group: flan_held_in
task:
- flan_boolq
- flan_rte
- flan_anli
- flan_arc
...@@ -3,37 +3,37 @@ task: ...@@ -3,37 +3,37 @@ task:
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
dataset_path: super_glue dataset_path: super_glue
dataset_name: boolq dataset_name: boolq
use_prompt: flan/prompt_templates/flan_boolq.yaml:* use_prompt: flan/prompt_templates/boolq.yaml:*
validation_split: validation validation_split: validation
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
dataset_path: super_glue dataset_path: super_glue
dataset_name: rte dataset_name: rte
use_prompt: flan/prompt_templates/flan_rte.yaml:* use_prompt: flan/prompt_templates/rte.yaml:*
validation_split: validation validation_split: validation
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: anli_r1 task: anli_r1
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/flan_anli.yaml:* use_prompt: flan/prompt_templates/anli.yaml:*
validation_split: dev_r1 validation_split: dev_r1
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: anli_r2 task: anli_r2
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/flan_anli.yaml:* use_prompt: flan/prompt_templates/anli.yaml:*
validation_split: dev_r2 validation_split: dev_r2
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: anli_r3 task: anli_r3
dataset_path: anli dataset_path: anli
use_prompt: flan/prompt_templates/flan_anli.yaml:* use_prompt: flan/prompt_templates/anli.yaml:*
validation_split: dev_r3 validation_split: dev_r3
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: arc_easy task: arc_easy
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Easy dataset_name: ARC-Easy
use_prompt: flan/prompt_templates/flan_arc.yaml:* use_prompt: flan/prompt_templates/arc.yaml:*
validation_split: validation validation_split: validation
- include: flan/yaml_templates/held_in_template_yaml - include: flan/yaml_templates/held_in_template_yaml
task: arc_challenge task: arc_challenge
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Challenge dataset_name: ARC-Challenge
use_prompt: flan/prompt_templates/flan_arc.yaml:* use_prompt: flan/prompt_templates/arc.yaml:*
validation_split: validation validation_split: validation
...@@ -426,7 +426,9 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None): ...@@ -426,7 +426,9 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
if yaml_config is None: if yaml_config is None:
with open(yaml_path, "rb") as file: with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file) yaml_config = yaml.full_load(file)
yaml_dir = os.path.dirname(yaml_path)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
assert yaml_dir is not None assert yaml_dir is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment