Unverified Commit 17191063 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Enable override of printed `n-shot` in table (#1379)

* allow tasks to specify printed fewshot val

* fix to belebele

* update metadata field's documentation
parent 994bdb3f
...@@ -50,7 +50,7 @@ Scoring details: ...@@ -50,7 +50,7 @@ Scoring details:
- **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`. - **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`.
Other: Other:
- **metadata** (`Union[str, list]`, *optional*) — An optional field where arbitrary metadata can be passed. A good example would be `version` that is used to denote the version of the yaml config. - **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field that is used to denote the version of the yaml config. Other special metadata keys are: `num_fewshot`, to override the printed `n-shot` table column for a task.
## Filters ## Filters
......
...@@ -86,9 +86,7 @@ class TaskConfig(dict): ...@@ -86,9 +86,7 @@ class TaskConfig(dict):
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
metadata: Union[ metadata: dict = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
str, list
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.generation_kwargs is not None: if self.generation_kwargs is not None:
......
...@@ -280,9 +280,12 @@ def evaluate( ...@@ -280,9 +280,12 @@ def evaluate(
configs[task_name] = dict(task.dump_config()) configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]: if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"] if configs[task_name]["metadata"]:
n_shot = configs[task_name]["metadata"].get("num_fewshot", None)
if not n_shot:
n_shot = configs[task_name]["num_fewshot"]
else: else:
n_shot = 0 n_shot = 0 # TODO: is this always right?
num_fewshot[task_name] = n_shot num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]: if "task_alias" in configs[task_name]:
...@@ -633,7 +636,7 @@ def evaluate( ...@@ -633,7 +636,7 @@ def evaluate(
for group_name, task_list in task_hierarchy.items(): for group_name, task_list in task_hierarchy.items():
if task_list != []: if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]] num_fewshot[group_name] = num_fewshot[task_list[0]] # TODO: validate this
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
......
...@@ -28,3 +28,4 @@ filter_list: ...@@ -28,3 +28,4 @@ filter_list:
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 2.0 version: 2.0
num_fewshot: 3 # controls what is printed in n-shot
...@@ -19,3 +19,4 @@ generation_kwargs: ...@@ -19,3 +19,4 @@ generation_kwargs:
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 1.0 version: 1.0
num_fewshot: 3 # will be printed in results table
...@@ -42,7 +42,7 @@ if __name__ == "__main__": ...@@ -42,7 +42,7 @@ if __name__ == "__main__":
print(query()) print(query())
languages = [split["split"] for split in query()] languages = [split["split"] for split in query()]
for lang in tqdm(languages): for lang in tqdm([lang for lang in languages if "default" not in lang]):
yaml_dict = { yaml_dict = {
"include": base_yaml_name, "include": base_yaml_name,
"task": f"belebele_{args.task_prefix}_{lang}" "task": f"belebele_{args.task_prefix}_{lang}"
......
"fewshot_split": "default"
"include": "_default_template_yaml"
"task": "belebele_default"
"test_split": "default"
...@@ -41,3 +41,4 @@ filter_list: ...@@ -41,3 +41,4 @@ filter_list:
- function: "take_first" - function: "take_first"
metadata: metadata:
version: 2.0 version: 2.0
num_fewshot: 8
...@@ -22,3 +22,4 @@ metric_list: ...@@ -22,3 +22,4 @@ metric_list:
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 1.0 version: 1.0
num_fewshot: 4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment