Commit 6c110147 authored by Baber's avatar Baber
Browse files

remove redundant jinja parsing logic

parent e66aa10c
...@@ -18,7 +18,6 @@ from typing import ( ...@@ -18,7 +18,6 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Union, Union,
cast,
) )
import datasets import datasets
...@@ -388,7 +387,7 @@ class Task(abc.ABC): ...@@ -388,7 +387,7 @@ class Task(abc.ABC):
def doc_to_audio(self, doc): def doc_to_audio(self, doc):
raise NotImplementedError raise NotImplementedError
def doc_to_prefix(self, doc): def doc_to_prefix(self, doc) -> Union[str, None]:
return "" return ""
def build_all_requests( def build_all_requests(
...@@ -1288,27 +1287,22 @@ class ConfigurableTask(Task): ...@@ -1288,27 +1287,22 @@ class ConfigurableTask(Task):
return doc return doc
def doc_to_text(self, doc, doc_to_text=None): def doc_to_text(self, doc, doc_to_text=None):
if self.prompt is not None: doc_to_text = doc_to_text or self.config.doc_to_text
doc_to_text = self.prompt
elif doc_to_text is not None:
doc_to_text = doc_to_text
else:
doc_to_text = self.config.doc_to_text
if isinstance(doc_to_text, int): if isinstance(doc_to_text, int):
return doc_to_text return doc_to_text
elif isinstance(doc_to_text, str): elif isinstance(doc_to_text, str):
if doc_to_text in self.features: if doc_to_text in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
# else:
return doc[doc_to_text] return doc[doc_to_text]
else: else:
text_string = utils.apply_template(doc_to_text, doc) text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None: if not self.multiple_inputs:
return ast.literal_eval(text_string)
else:
return text_string return text_string
else:
assert text_string.isdigit(), (
"doc_to_text should be int label for multiple_inputs!"
)
return ast.literal_eval(text_string)
elif callable(doc_to_text): elif callable(doc_to_text):
return doc_to_text(doc) return doc_to_text(doc)
# Used when applying a Promptsource template # Used when applying a Promptsource template
...@@ -1323,39 +1317,21 @@ class ConfigurableTask(Task): ...@@ -1323,39 +1317,21 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]: def doc_to_target(
if self.prompt is not None: self, doc: dict[str, Any], doc_to_target=None
doc_to_target = self.prompt ) -> Union[int, str, list]:
elif doc_to_target is not None: doc_to_target = doc_to_target or self.config.doc_to_target
doc_to_target = doc_to_target
else:
doc_to_target = self.config.doc_to_target
if isinstance(doc_to_target, int): if isinstance(doc_to_target, int) or isinstance(doc_to_target, list):
return doc_to_target return doc_to_target
elif isinstance(doc_to_target, str): elif isinstance(doc_to_target, str):
if doc_to_target in self.features: if doc_to_target in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
# else:
return doc[doc_to_target] return doc[doc_to_target]
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
else: else:
target_string = utils.apply_template(doc_to_target, doc) return target_string
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
len(target_string) >= 2
and (target_string[0] == "[")
and (target_string[-1] == "]")
):
try:
return ast.literal_eval(target_string)
except (SyntaxError, ValueError):
return target_string
else:
return target_string
elif isinstance(doc_to_target, list):
return doc_to_target
elif callable(doc_to_target): elif callable(doc_to_target):
return doc_to_target(doc) return doc_to_target(doc)
# Used when applying a Promptsource template # Used when applying a Promptsource template
...@@ -1370,20 +1346,16 @@ class ConfigurableTask(Task): ...@@ -1370,20 +1346,16 @@ class ConfigurableTask(Task):
raise TypeError raise TypeError
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]: def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
if self.prompt is not None: doc_to_choice = doc_to_choice or self.config.doc_to_choice
doc_to_choice = self.prompt if doc_to_choice is None:
elif doc_to_choice is not None:
doc_to_choice = doc_to_choice
elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config") eval_logger.error("doc_to_choice was called but not set in config")
else:
doc_to_choice = self.config.doc_to_choice
if isinstance(doc_to_choice, str): if isinstance(doc_to_choice, str):
if doc_to_choice in self.features: if doc_to_choice in self.features:
return doc[doc_to_choice] return doc[doc_to_choice]
else: else:
return cast(list, utils.apply_template(doc_to_choice, doc)) # literal_eval for parsing "{{[x, y]}}"
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list): elif isinstance(doc_to_choice, list):
return utils.apply_template(doc_to_choice, doc) return utils.apply_template(doc_to_choice, doc)
elif isinstance(doc_to_choice, dict): elif isinstance(doc_to_choice, dict):
...@@ -1441,19 +1413,20 @@ class ConfigurableTask(Task): ...@@ -1441,19 +1413,20 @@ class ConfigurableTask(Task):
else: else:
return None return None
def doc_to_prefix(self, doc): def doc_to_prefix(self, doc) -> Union[str, None]:
if (gen_prefix := self.config.gen_prefix) is not None: if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features: return (
return doc[gen_prefix] doc[gen_prefix]
else: if gen_prefix in self.features
return utils.apply_template(gen_prefix, doc) else utils.apply_template(gen_prefix, doc)
)
return None return None
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]: ) -> Union[List[Instance], Instance]:
apply_chat_template = kwargs.pop("apply_chat_template", False) apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None) chat_template: Union[Callable, None] = kwargs.pop("chat_template", None)
aux_arguments = None aux_arguments = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment