Unverified Commit 4e94af6f authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #522 from EleutherAI/fix-mgpt-fewshot

parents 84ef60ee 25699d3e
...@@ -26,6 +26,7 @@ Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx ...@@ -26,6 +26,7 @@ Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx
""" """
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval import utils
_CITATION = """ _CITATION = """
@inproceedings{yang-etal-2019-paws, @inproceedings{yang-etal-2019-paws,
...@@ -85,6 +86,11 @@ class PAWSXBase(Task): ...@@ -85,6 +86,11 @@ class PAWSXBase(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + [self.YES, self.NO][doc["label"]] return " " + [self.YES, self.NO][doc["label"]]
def doc_to_fewshot_prompt(self, doc):
prompt = self.doc_to_text(doc)
return prompt.replace("[MASK]", self.doc_to_target(doc)[1:])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -136,6 +142,76 @@ class PAWSXBase(Task): ...@@ -136,6 +142,76 @@ class PAWSXBase(Task):
def higher_is_better(self): def higher_is_better(self):
return {"acc": True} return {"acc": True}
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
# self.doc_to_text(doc) + self.doc_to_target(doc)
self.doc_to_fewshot_prompt(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
class PAWSX_en(PAWSXBase): class PAWSX_en(PAWSXBase):
DATASET_NAME = "en" DATASET_NAME = "en"
......
...@@ -18,6 +18,7 @@ Homepage: https://github.com/facebookresearch/XNLI ...@@ -18,6 +18,7 @@ Homepage: https://github.com/facebookresearch/XNLI
import numpy as np import numpy as np
from lm_eval.base import rf, Task from lm_eval.base import rf, Task
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval import utils
_CITATIONS = """ _CITATIONS = """
@InProceedings{conneau2018xnli, @InProceedings{conneau2018xnli,
...@@ -89,6 +90,11 @@ class XNLIBase(Task): ...@@ -89,6 +90,11 @@ class XNLIBase(Task):
] ]
) )
def doc_to_fewshot_prompt(self, doc):
prompt = self.doc_to_text(doc)
return prompt.replace("[MASK]", self.doc_to_target(doc)[1:])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -138,6 +144,76 @@ class XNLIBase(Task): ...@@ -138,6 +144,76 @@ class XNLIBase(Task):
""" """
return {"acc": True} return {"acc": True}
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
# self.doc_to_text(doc) + self.doc_to_target(doc)
self.doc_to_fewshot_prompt(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
class XNLI_en(XNLIBase): # English class XNLI_en(XNLIBase): # English
DATASET_NAME = "en" DATASET_NAME = "en"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment