Commit 6c110147 authored by Baber
Browse files

remove redundant jinja parsing logic

parent e66aa10c
......@@ -18,7 +18,6 @@ from typing import (
Optional,
Tuple,
Union,
cast,
)
import datasets
......@@ -388,7 +387,7 @@ class Task(abc.ABC):
def doc_to_audio(self, doc):
raise NotImplementedError
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc) -> Union[str, None]:
return ""
def build_all_requests(
......@@ -1288,27 +1287,22 @@ class ConfigurableTask(Task):
return doc
def doc_to_text(self, doc, doc_to_text=None):
if self.prompt is not None:
doc_to_text = self.prompt
elif doc_to_text is not None:
doc_to_text = doc_to_text
else:
doc_to_text = self.config.doc_to_text
doc_to_text = doc_to_text or self.config.doc_to_text
if isinstance(doc_to_text, int):
return doc_to_text
elif isinstance(doc_to_text, str):
if doc_to_text in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
# else:
return doc[doc_to_text]
else:
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
if not self.multiple_inputs:
return text_string
else:
assert text_string.isdigit(), (
"doc_to_text should be int label for multiple_inputs!"
)
return ast.literal_eval(text_string)
elif callable(doc_to_text):
return doc_to_text(doc)
# Used when applying a Promptsource template
......@@ -1323,39 +1317,21 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
elif doc_to_target is not None:
doc_to_target = doc_to_target
else:
doc_to_target = self.config.doc_to_target
def doc_to_target(
self, doc: dict[str, Any], doc_to_target=None
) -> Union[int, str, list]:
doc_to_target = doc_to_target or self.config.doc_to_target
if isinstance(doc_to_target, int):
if isinstance(doc_to_target, int) or isinstance(doc_to_target, list):
return doc_to_target
elif isinstance(doc_to_target, str):
if doc_to_target in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
# else:
return doc[doc_to_target]
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
else:
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
len(target_string) >= 2
and (target_string[0] == "[")
and (target_string[-1] == "]")
):
try:
return ast.literal_eval(target_string)
except (SyntaxError, ValueError):
return target_string
else:
return target_string
elif isinstance(doc_to_target, list):
return doc_to_target
return target_string
elif callable(doc_to_target):
return doc_to_target(doc)
# Used when applying a Promptsource template
......@@ -1370,20 +1346,16 @@ class ConfigurableTask(Task):
raise TypeError
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
if self.prompt is not None:
doc_to_choice = self.prompt
elif doc_to_choice is not None:
doc_to_choice = doc_to_choice
elif self.config.doc_to_choice is None:
doc_to_choice = doc_to_choice or self.config.doc_to_choice
if doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config")
else:
doc_to_choice = self.config.doc_to_choice
if isinstance(doc_to_choice, str):
if doc_to_choice in self.features:
return doc[doc_to_choice]
else:
return cast(list, utils.apply_template(doc_to_choice, doc))
# literal_eval for parsing "{{[x, y]}}"
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list):
return utils.apply_template(doc_to_choice, doc)
elif isinstance(doc_to_choice, dict):
......@@ -1441,19 +1413,20 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc) -> Union[str, None]:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
else:
return utils.apply_template(gen_prefix, doc)
return (
doc[gen_prefix]
if gen_prefix in self.features
else utils.apply_template(gen_prefix, doc)
)
return None
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None)
chat_template: Union[Callable, None] = kwargs.pop("chat_template", None)
aux_arguments = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment