Unverified Commit 17191063 authored by Hailey Schoelkopf, committed by GitHub
Browse files

Enable override of printed `n-shot` in table (#1379)

* allow tasks to specify printed fewshot val

* fix to belebele

* update metadata field's documentation
parent 994bdb3f
......@@ -50,7 +50,7 @@ Scoring details:
- **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`.
Other:
- **metadata** (`Union[str, list]`, *optional*) — An optional field where arbitrary metadata can be passed. A good example would be `version` that is used to denote the version of the yaml config.
- **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field, which denotes the version of the yaml config. One other key is treated specially: `num_fewshot`, which overrides the value printed in the `n-shot` column of the results table for a task.
## Filters
......
......@@ -86,9 +86,7 @@ class TaskConfig(dict):
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
metadata: Union[
str, list
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
metadata: dict = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
......
......@@ -280,9 +280,12 @@ def evaluate(
configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"]
if configs[task_name]["metadata"]:
n_shot = configs[task_name]["metadata"].get("num_fewshot", None)
if not n_shot:
n_shot = configs[task_name]["num_fewshot"]
else:
n_shot = 0
n_shot = 0 # TODO: is this always right?
num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]:
......@@ -633,7 +636,7 @@ def evaluate(
for group_name, task_list in task_hierarchy.items():
if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]]
num_fewshot[group_name] = num_fewshot[task_list[0]] # TODO: validate this
results_dict = {
"results": dict(results_agg.items()),
......
......@@ -28,3 +28,4 @@ filter_list:
num_fewshot: 0
metadata:
version: 2.0
num_fewshot: 3 # controls what is printed in n-shot
......@@ -19,3 +19,4 @@ generation_kwargs:
num_fewshot: 0
metadata:
version: 1.0
num_fewshot: 3 # will be printed in results table
......@@ -42,7 +42,7 @@ if __name__ == "__main__":
print(query())
languages = [split["split"] for split in query()]
for lang in tqdm(languages):
for lang in tqdm([lang for lang in languages if "default" not in lang]):
yaml_dict = {
"include": base_yaml_name,
"task": f"belebele_{args.task_prefix}_{lang}"
......
"fewshot_split": "default"
"include": "_default_template_yaml"
"task": "belebele_default"
"test_split": "default"
......@@ -41,3 +41,4 @@ filter_list:
- function: "take_first"
metadata:
version: 2.0
num_fewshot: 8
......@@ -22,3 +22,4 @@ metric_list:
num_fewshot: 0
metadata:
version: 1.0
num_fewshot: 4
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment