Commit 1646596a authored by Baber
Browse files

update template

parent 29cb9ecb
...@@ -18,7 +18,7 @@ class TemplateConfig(ABC): ...@@ -18,7 +18,7 @@ class TemplateConfig(ABC):
# #
template: str template: str
task: str task: str
doc_to_text: str | Callable[[dict], str] doc_to_text: str | Callable[[dict], str] | list[str]
doc_to_choice: str | list | Callable[[dict], list] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int] doc_to_target: int | Callable[[dict], int]
description: str description: str
...@@ -49,7 +49,7 @@ class TemplateConfig(ABC): ...@@ -49,7 +49,7 @@ class TemplateConfig(ABC):
@dataclass @dataclass
class MCQTemplateConfig(TemplateConfig): class MCQTemplateConfig:
"""Encapsulates information about a template. """Encapsulates information about a template.
Would return a sample with the following format: Would return a sample with the following format:
Question: <doc_to_text(doc)> Question: <doc_to_text(doc)>
...@@ -57,11 +57,11 @@ class MCQTemplateConfig(TemplateConfig): ...@@ -57,11 +57,11 @@ class MCQTemplateConfig(TemplateConfig):
B. <doc_to_choice(doc)[1]> B. <doc_to_choice(doc)[1]>
C. <doc_to_choice(doc)[2]> C. <doc_to_choice(doc)[2]>
D. <doc_to_choice(doc)[3]> D. <doc_to_choice(doc)[3]>
Answer:` doc_to_choice(doc)` for each choice. Answer: 'doc_to_choice(doc)` for each choice.
""" """
doc_to_text: str | Callable[[dict], str] doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list] doc_to_choice: list[str]
doc_to_target: int | Callable[[dict], int] doc_to_target: int | Callable[[dict], int]
template = "mcq" template = "mcq"
context_prefix: str = "Question:" context_prefix: str = "Question:"
...@@ -70,18 +70,27 @@ class MCQTemplateConfig(TemplateConfig): ...@@ -70,18 +70,27 @@ class MCQTemplateConfig(TemplateConfig):
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = "\n" target_delimiter: str = "\n"
choice_format: str | None = "letters" choice_format: str | None = "letters"
choice_delimiter: str | None = "\n" choice_delimiter: str = "\n"
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"]) metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
def _doc_to_text(self, doc: dict) -> str: def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text.""" """Convert a document to text."""
doc_to_text = ( doc_to_text: str = (
self.doc_to_text self.doc_to_text
if isinstance(self.doc_to_text, str) if isinstance(self.doc_to_text, str)
else self.doc_to_text(doc) else self.doc_to_text(doc)
) )
return self.context_prefix + doc_to_text return (
self.context_prefix
+ self.prefix_delimiter
+ doc_to_text
+ self.context_delimiter
+ create_mc_choices(
self.doc_to_choice, choice_delimiter=self.choice_delimiter
)
+ self.answer_suffix
)
def _doc_to_choice(self, doc: dict) -> str: def _doc_to_choice(self, doc: dict) -> str:
if callable(self.doc_to_choice): if callable(self.doc_to_choice):
...@@ -111,7 +120,7 @@ class ClozeTemplateConfig(TemplateConfig): ...@@ -111,7 +120,7 @@ class ClozeTemplateConfig(TemplateConfig):
""" """
doc_to_text: str | Callable[[dict], str] doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list] doc_to_choice: list[str]
doc_to_target: int | Callable[[dict], int] doc_to_target: int | Callable[[dict], int]
template: str = "cloze" template: str = "cloze"
description: str = "" description: str = ""
...@@ -121,8 +130,41 @@ class ClozeTemplateConfig(TemplateConfig): ...@@ -121,8 +130,41 @@ class ClozeTemplateConfig(TemplateConfig):
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = " " target_delimiter: str = " "
choice_format: str | None = None choice_format: str | None = None
choice_delimiter: str | None = None choice_delimiter: str = ""
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: list[MetricConfig] | None = field( metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text."""
doc_to_text: str = (
self.doc_to_text
if isinstance(self.doc_to_text, str)
else self.doc_to_text(doc)
)
return (
self.context_prefix
+ self.prefix_delimiter
+ doc_to_text
+ self.context_delimiter
+ self.answer_suffix
)
def _doc_to_choice(self, doc: dict) -> str:
    """Resolve the choices for ``doc`` and format them.

    ``self.doc_to_choice`` is resolved in one of three ways: a callable is
    invoked with ``doc``, a string is used as a key into ``doc``, and
    anything else is taken as the choices themselves. The resolved value is
    passed to ``create_mc_choices`` together with ``self.choice_delimiter``,
    and that helper's result is returned.
    """
    spec = self.doc_to_choice
    if callable(spec):
        choices = spec(doc)
    else:
        # A string names a field of the document; other values are literal.
        choices = doc[spec] if isinstance(spec, str) else spec
    return create_mc_choices(choices, choice_delimiter=self.choice_delimiter)
def _doc_to_target(self, doc: dict) -> int:
"""Convert a document to target."""
if callable(self.doc_to_target):
return self.doc_to_target(doc)
elif isinstance(self.doc_to_target, str):
return doc[self.doc_to_target]
else:
return self.doc_to_target
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment