Commit 3b2d5f6c authored by Leo Gao
Browse files

Pass random state around to fewshot_examples and fewshot_context

parent 35f58b95
......@@ -115,11 +115,10 @@ class Task(abc.ABC):
"""
return []
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
rnd = random.Random()
rnd.seed(42)
return rnd.sample(self._training_docs, k)
@abc.abstractmethod
......@@ -178,18 +177,17 @@ class Task(abc.ABC):
def fewshot_description(self):
return ""
def fewshot_context(self, doc, num_fewshot, provide_description):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
raw_description = self.fewshot_description()
description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot)
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs else self.test_docs())
rnd = random.Random()
rnd.seed(42)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
......
......@@ -43,6 +43,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
rnd=rnd
)
reqs = task.construct_requests(doc, ctx)
......
......@@ -218,7 +218,7 @@ class EthicsUtilitarianismOriginal(Ethics):
def fewshot_description(self):
return "Rate how pleasant each of the following activities is on a scale from 1 (very unpleasant) to 10 (very pleasant).\n\n"
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
# Overwriting fewshot examples as k can be max 5
assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
# These prompts implicitly (under)specify the task utility function
......
......@@ -99,13 +99,16 @@ class GeneralHendrycksTest(MultipleChoiceTask):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
rnd = random.Random()
rnd.seed(42)
return rnd.sample(list(self._load_docs(filename)), k)
if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename))
return rnd.sample(list(self._fewshot_docs), k)
def fewshot_description(self):
subject = self.subject.replace("_", " ")
......
......@@ -28,13 +28,11 @@ class NaturalQs(HFTask):
# Data is too large to fit in memory.
return self.data["train"]
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
# Data is too large to fit in memory. We just sample from the first bit.
if self._training_docs is None:
self._training_docs = list(islice(self.training_docs(), 0, 100000))
rnd = random.Random()
rnd.seed(42)
return rnd.sample(self._training_docs, k)
def doc_to_text(self, doc):
......
......@@ -56,12 +56,14 @@ class WinogradSchemaChallenge273(HFTask):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
rnd = random.Random()
rnd.seed(42)
return rnd.sample(list(self.test_docs()), k)
if self._fewshot_docs is None:
self._fewshot_docs = list(self.test_docs())
return rnd.sample(list(self._fewshot_docs), k)
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment