Commit 578f5d48 authored by lintangsutawika
Browse files

Add custom few-shot doc_to_text, doc_to_target, and doc_to_choice

parent 42d194f8
import datasets import datasets
from functools import partial
class ContextSampler: class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None: def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
...@@ -16,12 +16,28 @@ class ContextSampler: ...@@ -16,12 +16,28 @@ class ContextSampler:
self.fewshot_delimiter = self.config.fewshot_delimiter self.fewshot_delimiter = self.config.fewshot_delimiter
if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_text", None) is not None: if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_text", None) is not None:
self.doc_to_text = self.config.fewshot_config.get("doc_to_text", None) self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None)
)
else: else:
self.doc_to_text = self.task.doc_to_text self.doc_to_text = self.task.doc_to_text
self.doc_to_target = self.task.doc_to_target if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_target", None) is not None:
self.doc_to_choice = self.task.doc_to_choice self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None)
)
else:
self.doc_to_target = self.task.doc_to_target
if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_choice", None) is not None:
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None)
)
else:
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs() self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from if fewshot_indices: # subset few-shot docs from
...@@ -56,14 +72,15 @@ class ContextSampler: ...@@ -56,14 +72,15 @@ class ContextSampler:
else self.doc_to_choice(doc)[doc_content] else self.doc_to_choice(doc)[doc_content]
) )
labeled_examples += self.target_delimiter labeled_examples += self.target_delimiter
labeled_examples += ( if doc_target is not "":
str(doc_target[0]) labeled_examples += (
if isinstance(doc_target, list) str(doc_target[0])
else doc_target if isinstance(doc_target, list)
if self.config.doc_to_choice is None or isinstance(doc_target, str) else doc_target
else str(self.doc_to_choice(doc)[doc_target]) if self.config.doc_to_choice is None or isinstance(doc_target, str)
) else str(self.doc_to_choice(doc)[doc_target])
labeled_examples += self.fewshot_delimiter )
labeled_examples += self.fewshot_delimiter
return labeled_examples return labeled_examples
......
...@@ -1158,9 +1158,11 @@ class ConfigurableTask(Task): ...@@ -1158,9 +1158,11 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc, doc_to_text=None):
if self.prompt is not None: if self.prompt is not None:
doc_to_text = self.prompt doc_to_text = self.prompt
elif doc_to_text is not None:
doc_to_text = doc_to_text
else: else:
doc_to_text = self.config.doc_to_text doc_to_text = self.config.doc_to_text
...@@ -1192,9 +1194,11 @@ class ConfigurableTask(Task): ...@@ -1192,9 +1194,11 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]: def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
elif doc_to_target is not None:
doc_to_target = doc_to_target
else: else:
doc_to_target = self.config.doc_to_target doc_to_target = self.config.doc_to_target
...@@ -1236,9 +1240,11 @@ class ConfigurableTask(Task): ...@@ -1236,9 +1240,11 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]: def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_choice = self.prompt doc_to_choice = self.prompt
elif doc_to_choice is not None:
doc_to_choice = doc_to_choice
elif self.config.doc_to_choice is None: elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config") eval_logger.error("doc_to_choice was called but not set in config")
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment