Commit e66aa10c authored by Baber's avatar Baber
Browse files

fix processing regression; recursively parse lists fron template

parent 3d5fa4c7
......@@ -18,6 +18,7 @@ from typing import (
Optional,
Tuple,
Union,
cast,
)
import datasets
......@@ -1382,9 +1383,9 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features:
return doc[doc_to_choice]
else:
return utils.apply_template(doc_to_choice, doc)
return cast(list, utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list):
return doc_to_choice
return utils.apply_template(doc_to_choice, doc)
elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values())
elif callable(doc_to_choice):
......@@ -1606,8 +1607,8 @@ class ConfigurableTask(Task):
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
gold = (
self.doc_to_text(doc)
gold = backup = (
self.doc_to_target(doc)
if not self.multiple_inputs
else self.doc_to_text(doc)
)
......@@ -1625,7 +1626,7 @@ class ConfigurableTask(Task):
if gold_index_error:
eval_logger.warning(
f"Label index was not in within range of available choices,"
f"Label [{backup}] index was not in within range of available choices {choices},"
f"Sample:\n\n{doc}\n\n"
)
......
......@@ -8,7 +8,7 @@ training_split: train
validation_split: validation
doc_to_text: !function utils.doc_to_text
doc_to_target: label
doc_to_choice: "{{ [choice1, choice2] }}"
doc_to_choice: ["{{ choice1 }}", "{{ choice2 }}"]
metric_list:
- metric: acc
metadata:
......
......@@ -11,7 +11,7 @@ import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
from typing import Any, Callable, Generator, List, Optional, Tuple, Union, overload
import numpy as np
import yaml
......@@ -545,7 +545,20 @@ env = Environment(
env.filters["regex_replace"] = regex_replace
def apply_template(template: str, doc: dict) -> str:
@overload
def apply_template(template: str, doc: dict[str, Any]) -> str: ...
@overload
def apply_template(template: list[str], doc: dict[str, Any]) -> list[str]: ...
def apply_template(template: Union[str, list[str]], doc: dict) -> Union[str, list[str]]:
if isinstance(template, list):
return [
apply_template(x, doc) if (x.startswith("{{") and x.endswith("}}")) else x
for x in template
]
rtemplate = env.from_string(template)
return rtemplate.render(**doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment