Commit 3b2d5f6c authored by Leo Gao
Browse files

Pass random state around to fewshot_examples and fewshot_context

parent 35f58b95
...@@ -115,11 +115,10 @@ class Task(abc.ABC): ...@@ -115,11 +115,10 @@ class Task(abc.ABC):
""" """
return [] return []
def fewshot_examples(self, k): def fewshot_examples(self, k, rnd):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.training_docs()) self._training_docs = list(self.training_docs())
rnd = random.Random()
rnd.seed(42)
return rnd.sample(self._training_docs, k) return rnd.sample(self._training_docs, k)
@abc.abstractmethod @abc.abstractmethod
...@@ -178,18 +177,17 @@ class Task(abc.ABC): ...@@ -178,18 +177,17 @@ class Task(abc.ABC):
def fewshot_description(self): def fewshot_description(self):
return "" return ""
def fewshot_context(self, doc, num_fewshot, provide_description): def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
raw_description = self.fewshot_description() raw_description = self.fewshot_description()
description = (raw_description + "\n===\n\n") if provide_description and raw_description else "" description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
# for sets with no training docs, draw from other set *but ensure no overlap with current doc* # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs(): if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot) fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else: else:
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs else self.test_docs()) self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs else self.test_docs())
rnd = random.Random()
rnd.seed(42)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot # get rid of the doc that's the one we're evaluating, if it's in the fewshot
......
...@@ -43,6 +43,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -43,6 +43,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
doc=doc, doc=doc,
provide_description=provide_description, provide_description=provide_description,
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
rnd=rnd
) )
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
......
...@@ -218,7 +218,7 @@ class EthicsUtilitarianismOriginal(Ethics): ...@@ -218,7 +218,7 @@ class EthicsUtilitarianismOriginal(Ethics):
def fewshot_description(self): def fewshot_description(self):
return "Rate how pleasant each of the following activities is on a scale from 1 (very unpleasant) to 10 (very pleasant).\n\n" return "Rate how pleasant each of the following activities is on a scale from 1 (very unpleasant) to 10 (very pleasant).\n\n"
def fewshot_examples(self, k): def fewshot_examples(self, k, rnd):
# Overwriting fewshot examples as k can be max 5 # Overwriting fewshot examples as k can be max 5
assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more." assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
# These prompts implicitly (under)specify the task utility function # These prompts implicitly (under)specify the task utility function
......
...@@ -99,13 +99,16 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -99,13 +99,16 @@ class GeneralHendrycksTest(MultipleChoiceTask):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv" filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename) return self._load_docs(filename)
def fewshot_examples(self, k): def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is # fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't # in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv" filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
rnd = random.Random()
rnd.seed(42) if self._fewshot_docs is None:
return rnd.sample(list(self._load_docs(filename)), k) self._fewshot_docs = list(self._load_docs(filename))
return rnd.sample(list(self._fewshot_docs), k)
def fewshot_description(self): def fewshot_description(self):
subject = self.subject.replace("_", " ") subject = self.subject.replace("_", " ")
......
...@@ -28,13 +28,11 @@ class NaturalQs(HFTask): ...@@ -28,13 +28,11 @@ class NaturalQs(HFTask):
# Data is too large to fit in memory. # Data is too large to fit in memory.
return self.data["train"] return self.data["train"]
def fewshot_examples(self, k): def fewshot_examples(self, k, rnd):
# Data is too large to fit in memory. We just sample from the first bit. # Data is too large to fit in memory. We just sample from the first bit.
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(islice(self.training_docs(), 0, 100000)) self._training_docs = list(islice(self.training_docs(), 0, 100000))
rnd = random.Random()
rnd.seed(42)
return rnd.sample(self._training_docs, k) return rnd.sample(self._training_docs, k)
def doc_to_text(self, doc): def doc_to_text(self, doc):
......
...@@ -56,12 +56,14 @@ class WinogradSchemaChallenge273(HFTask): ...@@ -56,12 +56,14 @@ class WinogradSchemaChallenge273(HFTask):
# TODO: redo description # TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False." return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k): def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are # NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset. # not available for this test-set-only dataset.
rnd = random.Random()
rnd.seed(42) if self._fewshot_docs is None:
return rnd.sample(list(self.test_docs()), k) self._fewshot_docs = list(self.test_docs())
return rnd.sample(list(self._fewshot_docs), k)
def doc_to_text(self, doc): def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]]) return self.partial_context(doc, doc["options"][doc["label"]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment