Unverified Commit 4e94af6f authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #522 from EleutherAI/fix-mgpt-fewshot

parents 84ef60ee 25699d3e
......@@ -26,6 +26,7 @@ Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx
"""
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from lm_eval import utils
_CITATION = """
@inproceedings{yang-etal-2019-paws,
......@@ -85,6 +86,11 @@ class PAWSXBase(Task):
def doc_to_target(self, doc):
return " " + [self.YES, self.NO][doc["label"]]
def doc_to_fewshot_prompt(self, doc):
prompt = self.doc_to_text(doc)
return prompt.replace("[MASK]", self.doc_to_target(doc)[1:])
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -136,6 +142,76 @@ class PAWSXBase(Task):
def higher_is_better(self):
return {"acc": True}
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
# self.doc_to_text(doc) + self.doc_to_target(doc)
self.doc_to_fewshot_prompt(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
class PAWSX_en(PAWSXBase):
DATASET_NAME = "en"
......
......@@ -18,6 +18,7 @@ Homepage: https://github.com/facebookresearch/XNLI
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
from lm_eval import utils
_CITATIONS = """
@InProceedings{conneau2018xnli,
......@@ -89,6 +90,11 @@ class XNLIBase(Task):
]
)
def doc_to_fewshot_prompt(self, doc):
prompt = self.doc_to_text(doc)
return prompt.replace("[MASK]", self.doc_to_target(doc)[1:])
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -138,6 +144,76 @@ class XNLIBase(Task):
"""
return {"acc": True}
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
# self.doc_to_text(doc) + self.doc_to_target(doc)
self.doc_to_fewshot_prompt(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
class XNLI_en(XNLIBase): # English
DATASET_NAME = "en"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment