Commit e66aa10c authored by Baber's avatar Baber
Browse files

fix processing regression; recursively parse lists fron template

parent 3d5fa4c7
...@@ -18,6 +18,7 @@ from typing import ( ...@@ -18,6 +18,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Union, Union,
cast,
) )
import datasets import datasets
...@@ -1382,9 +1383,9 @@ class ConfigurableTask(Task): ...@@ -1382,9 +1383,9 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features: if doc_to_choice in self.features:
return doc[doc_to_choice] return doc[doc_to_choice]
else: else:
return utils.apply_template(doc_to_choice, doc) return cast(list, utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list): elif isinstance(doc_to_choice, list):
return doc_to_choice return utils.apply_template(doc_to_choice, doc)
elif isinstance(doc_to_choice, dict): elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values()) return list(doc_to_choice.values())
elif callable(doc_to_choice): elif callable(doc_to_choice):
...@@ -1606,8 +1607,8 @@ class ConfigurableTask(Task): ...@@ -1606,8 +1607,8 @@ class ConfigurableTask(Task):
pred = np.argmax(lls) pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len) pred_norm = np.argmax(lls / completion_len)
gold = ( gold = backup = (
self.doc_to_text(doc) self.doc_to_target(doc)
if not self.multiple_inputs if not self.multiple_inputs
else self.doc_to_text(doc) else self.doc_to_text(doc)
) )
...@@ -1625,7 +1626,7 @@ class ConfigurableTask(Task): ...@@ -1625,7 +1626,7 @@ class ConfigurableTask(Task):
if gold_index_error: if gold_index_error:
eval_logger.warning( eval_logger.warning(
f"Label index was not in within range of available choices," f"Label [{backup}] index was not in within range of available choices {choices},"
f"Sample:\n\n{doc}\n\n" f"Sample:\n\n{doc}\n\n"
) )
......
...@@ -8,7 +8,7 @@ training_split: train ...@@ -8,7 +8,7 @@ training_split: train
validation_split: validation validation_split: validation
doc_to_text: !function utils.doc_to_text doc_to_text: !function utils.doc_to_text
doc_to_target: label doc_to_target: label
doc_to_choice: "{{ [choice1, choice2] }}" doc_to_choice: ["{{ choice1 }}", "{{ choice2 }}"]
metric_list: metric_list:
- metric: acc - metric: acc
metadata: metadata:
......
...@@ -11,7 +11,7 @@ import re ...@@ -11,7 +11,7 @@ import re
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple from typing import Any, Callable, Generator, List, Optional, Tuple, Union, overload
import numpy as np import numpy as np
import yaml import yaml
...@@ -545,7 +545,20 @@ env = Environment( ...@@ -545,7 +545,20 @@ env = Environment(
env.filters["regex_replace"] = regex_replace env.filters["regex_replace"] = regex_replace
def apply_template(template: str, doc: dict) -> str: @overload
def apply_template(template: str, doc: dict[str, Any]) -> str: ...
@overload
def apply_template(template: list[str], doc: dict[str, Any]) -> list[str]: ...
def apply_template(template: Union[str, list[str]], doc: dict) -> Union[str, list[str]]:
if isinstance(template, list):
return [
apply_template(x, doc) if (x.startswith("{{") and x.endswith("}}")) else x
for x in template
]
rtemplate = env.from_string(template) rtemplate = env.from_string(template)
return rtemplate.render(**doc) return rtemplate.render(**doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment