Commit 5fbc3f86 authored by haileyschoelkopf

fix random seed issue, log_samples optional

parent 09a71562
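
What the seed fix does: task samplers are now constructed with `random.Random(1234)` instead of a fresh, unseeded `random.Random()`, and `simple_evaluate` seeds `random`, `numpy`, and `torch` up front, so few-shot example selection is reproducible across runs. A minimal standalone sketch of the difference (illustrative only; not code from this repository):

```python
import random

# An unseeded Random() takes its state from OS entropy, so the few-shot
# examples drawn for a document can change from run to run.
nondeterministic = random.Random()

# Seeding with a fixed value (1234, as in this commit) makes the draws
# identical across runs and across ranks.
deterministic = random.Random(1234)

docs = list(range(100))
print(nondeterministic.sample(docs, 5))  # may differ between runs
print(deterministic.sample(docs, 5))     # same 5 documents on every run
```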
@@ -8,6 +8,7 @@ import evaluate
 import random
 import itertools
 import functools
+from tqdm import tqdm
 import datasets
 import numpy as np
@@ -217,8 +218,8 @@ class Task(abc.ABC):
             self._filters.append(filter_pipeline)
 
         self.sampler = samplers.Sampler(
-            list(self.fewshot_docs()), self, rnd=random.Random()
-        )  # TODO: pass the correct docs in here
+            list(self.fewshot_docs()), self, rnd=random.Random(1234)
+        )
 
     def download(self, data_dir=None, cache_dir=None, download_mode=None):
         """Downloads and returns the task dataset.
@@ -366,13 +367,18 @@ class Task(abc.ABC):
                 False
             ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
 
+        eval_logger.info(
+            f"Building contexts for task '{self._config.task}' on rank {rank}..."
+        )
+
         instances = []
         for doc_id, doc in utils.create_iterator(
             enumerate(docs), rank, world_size, limit
         ):
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(
-                doc, self._config.num_fewshot, rnd=random.Random()
+                doc,
+                self._config.num_fewshot,
             )
 
             # TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
@@ -453,7 +459,7 @@ class Task(abc.ABC):
         return len(re.split(r"\s+", doc))
 
     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot, rnd=None):
+    def fewshot_context(self, doc, num_fewshot):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -461,15 +467,9 @@ class Task(abc.ABC):
             The document as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
         :returns: str
             The fewshot context.
         """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
         if num_fewshot == 0:
             # always prepend the (possibly empty) task description
@@ -625,7 +625,7 @@ class ConfigurableTask(Task):
         if self.fewshot_docs() is not None:
             self.sampler = samplers.Sampler(
-                list(self.fewshot_docs()), self, rnd=random.Random()
+                list(self.fewshot_docs()), self, rnd=random.Random(1234)
             )
 
     def download(self, dataset_kwargs=None):
@@ -1004,13 +1004,10 @@ class PerplexityTask(Task):
         assert k == 0
         return []
 
-    def fewshot_context(self, doc, num_fewshot, rnd=None):
+    def fewshot_context(self, doc, num_fewshot):
         assert (
             num_fewshot == 0
         ), "The number of fewshot examples must be 0 for perplexity tasks."
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`."
         return ""
@@ -45,6 +45,7 @@ def simple_evaluate(
     check_integrity=False,
     decontamination_ngrams_path=None,
     write_out=False,
+    log_samples=True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -72,12 +73,17 @@ def simple_evaluate(
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :param write_out: bool
-        If True, write details about prompts and logits to json for all tasks
+        If True, write out an example document and model input for checking task integrity
+    :param log_samples: bool
+        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
     :return
         Dictionary of results
     """
-    random.seed(1234)
+    random.seed(0)
     np.random.seed(1234)
+    torch.manual_seed(
+        1234
+    )  # TODO: this may affect training runs that are run with evaluation mid-run.
 
     assert tasks != [], "No tasks specified"
@@ -118,6 +124,7 @@ def simple_evaluate(
         bootstrap_iters=bootstrap_iters,
         decontamination_ngrams_path=decontamination_ngrams_path,
         write_out=write_out,
+        log_samples=log_samples,
     )
 
     if lm.rank == 0:
@@ -154,6 +161,7 @@ def evaluate(
     bootstrap_iters=100000,
     decontamination_ngrams_path=None,
     write_out=False,
+    log_samples=True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -168,7 +176,9 @@ def evaluate(
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param write_out: bool
-        If True, write all prompts, logits and metrics to json for offline analysis
+        If True, write out an example document and model input for checking task integrity
+    :param log_samples: bool
+        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
     :return
         Dictionary of results
     """
@@ -282,6 +292,7 @@ def evaluate(
                 metrics = task.process_results(
                     doc, [req.filtered_resps[key] for req in requests]
                 )
+                if log_samples:
                     target = task.doc_to_target(doc)
                     example = {
                         "doc_id": doc_id,
@@ -366,7 +377,7 @@ def evaluate(
             "results": dict(results),
             "configs": dict(configs),
             "versions": dict(versions),
-            "samples": samples,
+            "samples": samples if log_samples else {},
         }
     else:
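To make the evaluator change concrete, here is a self-contained sketch of how the new `log_samples` flag gates per-sample records in the returned dictionary. The dict keys mirror the diff; the helper function and its inputs are hypothetical stand-ins for the surrounding `evaluate()` machinery:

```python
def build_results_dict(results, configs, versions, samples, log_samples=True):
    # Mirrors the construction in the diff: per-sample records are only
    # included when log_samples is True; otherwise "samples" is left empty.
    return {
        "results": dict(results),
        "configs": dict(configs),
        "versions": dict(versions),
        "samples": samples if log_samples else {},
    }


out = build_results_dict(
    results={"my_task": {"acc": 0.5}},
    configs={"my_task": {}},
    versions={"my_task": 1.0},
    samples={"my_task": [{"doc_id": 0, "target": "..."}]},
    log_samples=False,
)
assert out["samples"] == {}  # per-sample records dropped when logging is disabled
```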
@@ -43,6 +43,7 @@ def parse_args():
     parser.add_argument("--decontamination_ngrams_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
     parser.add_argument("--write_out", action="store_true", default=False)
+    parser.add_argument("--log_samples", action="store_true", default=True)
     return parser.parse_args()
@@ -89,9 +90,11 @@ def main():
         decontamination_ngrams_path=args.decontamination_ngrams_path,
         check_integrity=args.check_integrity,
         write_out=args.write_out,
+        log_samples=args.log_samples,
     )
 
     if results is not None:
+        if args.log_samples:
             samples = results.pop("samples")
         dumped = json.dumps(results, indent=2, default=lambda o: str(o))
         print(dumped)
@@ -104,6 +107,7 @@ def main():
         with open(args.output_path, "w") as f:
             f.write(dumped)
 
+        if args.log_samples:
             for task_name, config in results["configs"].items():
                 output_name = "{}_{}".format(
                     re.sub("/", "__", args.model_args), task_name
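Finally, a standalone sketch of the argparse behavior added to the CLI (only the `--log_samples` definition is copied from the diff). Because the argument uses `action="store_true"` together with `default=True`, the parsed value is `True` whether or not the flag is passed; disabling sample logging would have to go through the `log_samples` keyword argument of `simple_evaluate` instead:

```python
import argparse

parser = argparse.ArgumentParser()
# Definition copied from the diff above.
parser.add_argument("--log_samples", action="store_true", default=True)

print(parser.parse_args([]).log_samples)                 # True (default)
print(parser.parse_args(["--log_samples"]).log_samples)  # True (flag passed)
```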