Commit 5a8ac198 authored by lintangsutawika's avatar lintangsutawika
Browse files

Reimplemented the `fewshot_context` method in the class to allow a custom prompt for few-shot examples

parent 84ef60ee
...@@ -26,6 +26,7 @@ Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx ...@@ -26,6 +26,7 @@ Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx
""" """
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval import utils
_CITATION = """ _CITATION = """
@inproceedings{yang-etal-2019-paws, @inproceedings{yang-etal-2019-paws,
...@@ -85,6 +86,11 @@ class PAWSXBase(Task): ...@@ -85,6 +86,11 @@ class PAWSXBase(Task):
def doc_to_target(self, doc):
    """Return the answer word for `doc`, prefixed with a single space.

    :param doc: dict
        A document whose "label" entry is used as an index into
        `[self.YES, self.NO]`.
    :returns: str
        `" " + self.YES` when `doc["label"]` is 0 and `" " + self.NO` when
        it is 1.
    """
    # NOTE(review): label 0 selects YES and label 1 selects NO -- confirm this
    # matches the dataset's label convention.
    return " " + [self.YES, self.NO][doc["label"]]
def doc_to_fewshot_prompt(self, doc):
    """Render `doc` as a solved example for use in a fewshot context.

    The text from `doc_to_text` is taken as a template and every "[MASK]"
    occurrence in it is replaced by the target from `doc_to_target` with its
    first character dropped.

    :param doc: the document to render; passed unchanged to `doc_to_text`
        and `doc_to_target`.
    :returns: str
        The template text with "[MASK]" filled in.
    """
    template = self.doc_to_text(doc)
    # Drop the first character of the target before splicing it in.
    answer = self.doc_to_target(doc)[1:]
    return template.replace("[MASK]", answer)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -136,6 +142,76 @@ class PAWSXBase(Task): ...@@ -136,6 +142,76 @@ class PAWSXBase(Task):
def higher_is_better(self):
    """Return a dict mapping each metric name to whether larger is better.

    :returns: dict
        `{"acc": True}`.
    """
    return {"acc": True}
@utils.positional_deprecated
def fewshot_context(
    self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
    """Assemble the full context string for evaluating `doc`.

    The result is, in order: the optional description followed by a blank
    line, `num_fewshot` solved examples rendered with
    `doc_to_fewshot_prompt` (each followed by a blank line), and finally the
    unsolved text of `doc` from `doc_to_text`.

    :param doc:
        The document being evaluated.
    :param num_fewshot: int
        How many solved examples to place before the evaluated document.
    :param provide_description: bool
        Deprecated. Must be falsy; passing any value other than `None`
        prints a deprecation warning.
    :param rnd: random.Random
        Generator used to sample the examples. Required even though the
        default is `None`.
    :param description: str
        Optional text prepended to the context.
    :returns: str
        The fewshot context.
    """
    assert (
        rnd is not None
    ), "A `random.Random` generator argument must be provided to `rnd`"
    assert not provide_description, (
        "The `provide_description` arg will be removed in future versions. To prepend "
        "a custom description to the context, supply the corresponding string via the "
        "`description` arg."
    )
    if provide_description is not None:
        # Discourage callers from passing the argument at all.
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    header = description + "\n\n" if description else ""

    solved_block = ""
    if num_fewshot != 0:
        if self.has_training_docs():
            shots = self.fewshot_examples(k=num_fewshot, rnd=rnd)
        else:
            # No training split: sample from the validation docs (or the test
            # docs if there are none), caching the materialised list on first use.
            if self._fewshot_docs is None:
                if self.has_validation_docs():
                    pool = self.validation_docs()
                else:
                    pool = self.test_docs()
                self._fewshot_docs = list(pool)

            # Draw one extra candidate so that the evaluated document can be
            # dropped if it was sampled, while still leaving enough examples.
            candidates = rnd.sample(self._fewshot_docs, num_fewshot + 1)
            shots = [cand for cand in candidates if cand != doc][:num_fewshot]

        rendered = [self.doc_to_fewshot_prompt(shot) for shot in shots]
        solved_block = "\n\n".join(rendered) + "\n\n"

    query = self.doc_to_text(doc)
    return header + solved_block + query
class PAWSX_en(PAWSXBase): class PAWSX_en(PAWSXBase):
DATASET_NAME = "en" DATASET_NAME = "en"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment