Commit 5fbc3f86 authored by haileyschoelkopf

fix random seed issue, log_samples optional

parent 09a71562
@@ -8,6 +8,7 @@
 import evaluate
 import random
 import itertools
 import functools
+from tqdm import tqdm
 import datasets
 import numpy as np
@@ -217,8 +218,8 @@ class Task(abc.ABC):
             self._filters.append(filter_pipeline)

         self.sampler = samplers.Sampler(
-            list(self.fewshot_docs()), self, rnd=random.Random()
-        ) # TODO: pass the correct docs in here
+            list(self.fewshot_docs()), self, rnd=random.Random(1234)
+        )

     def download(self, data_dir=None, cache_dir=None, download_mode=None):
         """Downloads and returns the task dataset.
@@ -366,13 +367,18 @@ class Task(abc.ABC):
                 False
             ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

+        eval_logger.info(
+            f"Building contexts for task '{self._config.task}' on rank {rank}..."
+        )
+
         instances = []
         for doc_id, doc in utils.create_iterator(
             enumerate(docs), rank, world_size, limit
         ):
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(
-                doc, self._config.num_fewshot, rnd=random.Random()
+                doc,
+                self._config.num_fewshot,
             )

             # TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
@@ -453,7 +459,7 @@
         return len(re.split(r"\s+", doc))

     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot, rnd=None):
+    def fewshot_context(self, doc, num_fewshot):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -461,15 +467,9 @@
             The document as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
         :returns: str
             The fewshot context.
         """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"

         if num_fewshot == 0:
             # always prepend the (possibly empty) task description
@@ -625,7 +625,7 @@ class ConfigurableTask(Task):
         if self.fewshot_docs() is not None:
             self.sampler = samplers.Sampler(
-                list(self.fewshot_docs()), self, rnd=random.Random()
+                list(self.fewshot_docs()), self, rnd=random.Random(1234)
             )

     def download(self, dataset_kwargs=None):
@@ -1004,13 +1004,10 @@ class PerplexityTask(Task):
         assert k == 0
         return []

-    def fewshot_context(self, doc, num_fewshot, rnd=None):
+    def fewshot_context(self, doc, num_fewshot):
         assert (
             num_fewshot == 0
         ), "The number of fewshot examples must be 0 for perplexity tasks."
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`."

         return ""
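The task changes above replace an unseeded `random.Random()` with a fixed `random.Random(1234)` for the few-shot sampler and drop the per-call `rnd` argument from `fewshot_context`. A minimal standalone sketch of why a dedicated, explicitly seeded generator helps (illustrative code only, not part of this commit; `pool` is an assumed stand-in for `list(self.fewshot_docs())`):

import random

pool = list(range(100))  # assumed stand-in for list(self.fewshot_docs())

# Two generators built from the same seed pick the same few-shot examples,
# so few-shot selection is reproducible from run to run.
assert random.Random(1234).sample(pool, 5) == random.Random(1234).sample(pool, 5)

# A dedicated Random instance is also isolated from the module-level seed
# that simple_evaluate now sets, so reseeding the global generator elsewhere
# cannot perturb which examples are drawn.
random.seed(0)
first = random.Random(1234).sample(pool, 5)
random.seed(42)
second = random.Random(1234).sample(pool, 5)
assert first == second

Because the sampler owns its generator, few-shot selection no longer depends on whoever last touched the global `random` state.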
@@ -45,6 +45,7 @@ def simple_evaluate(
     check_integrity=False,
     decontamination_ngrams_path=None,
     write_out=False,
+    log_samples=True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -72,12 +73,17 @@
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :param write_out: bool
-        If True, write details about prompts and logits to json for all tasks
+        If True, write out an example document and model input for checking task integrity
+    :param log_samples: bool
+        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
     :return
         Dictionary of results
     """
-    random.seed(1234)
+    random.seed(0)
+    np.random.seed(1234)
+    torch.manual_seed(
+        1234
+    )  # TODO: this may affect training runs that are run with evaluation mid-run.

     assert tasks != [], "No tasks specified"
@@ -118,6 +124,7 @@
         bootstrap_iters=bootstrap_iters,
         decontamination_ngrams_path=decontamination_ngrams_path,
         write_out=write_out,
+        log_samples=log_samples,
     )

     if lm.rank == 0:
@@ -154,6 +161,7 @@ def evaluate(
     bootstrap_iters=100000,
     decontamination_ngrams_path=None,
     write_out=False,
+    log_samples=True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -168,7 +176,9 @@
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param write_out: bool
-        If True, write all prompts, logits and metrics to json for offline analysis
+        If True, write out an example document and model input for checking task integrity
+    :param log_samples: bool
+        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
     :return
         Dictionary of results
     """
@@ -282,17 +292,18 @@
                 metrics = task.process_results(
                     doc, [req.filtered_resps[key] for req in requests]
                 )
-                target = task.doc_to_target(doc)
-                example = {
-                    "doc_id": doc_id,
-                    "doc": doc,
-                    "target": target,
-                    "arguments": [req.args for req in requests],
-                    "resps": [req.resps for req in requests],
-                    "filtered_resps": [req.filtered_resps[key] for req in requests],
-                }
-                example.update(metrics)
-                samples[task_name].append(example)
+                if log_samples:
+                    target = task.doc_to_target(doc)
+                    example = {
+                        "doc_id": doc_id,
+                        "doc": doc,
+                        "target": target,
+                        "arguments": [req.args for req in requests],
+                        "resps": [req.resps for req in requests],
+                        "filtered_resps": [req.filtered_resps[key] for req in requests],
+                    }
+                    example.update(metrics)
+                    samples[task_name].append(example)
                 for metric, value in metrics.items():
                     vals[(task_name, key, metric)].append(value)
@@ -366,7 +377,7 @@ def evaluate(
             "results": dict(results),
             "configs": dict(configs),
             "versions": dict(versions),
-            "samples": samples,
+            "samples": samples if log_samples else {},
         }
     else:
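In the evaluator changes above, `log_samples` is threaded from `simple_evaluate` into `evaluate`, per-document records are only collected when it is true, and the returned `"samples"` entry is an empty dict otherwise. A hedged usage sketch; the import path, model string, and task name below are illustrative assumptions, not taken from this diff:

from lm_eval import evaluator  # assumed import path for the module patched above

results = evaluator.simple_evaluate(
    model="hf",                    # assumed model type string
    model_args="pretrained=gpt2",  # assumed model arguments
    tasks=["lambada_openai"],      # assumed task name
    num_fewshot=0,
    log_samples=False,             # new flag: skip per-sample logging
)

# With log_samples=False the returned dict carries only aggregates;
# "samples" is an empty dict rather than one record per document
# (doc_id, doc, target, arguments, resps, filtered_resps, metrics).
print(results["results"])
assert results["samples"] == {}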
@@ -43,6 +43,7 @@ def parse_args():
     parser.add_argument("--decontamination_ngrams_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
     parser.add_argument("--write_out", action="store_true", default=False)
+    parser.add_argument("--log_samples", action="store_true", default=True)

     return parser.parse_args()
@@ -89,10 +90,12 @@ def main():
         decontamination_ngrams_path=args.decontamination_ngrams_path,
         check_integrity=args.check_integrity,
         write_out=args.write_out,
+        log_samples=args.log_samples,
     )

     if results is not None:
-        samples = results.pop("samples")
+        if args.log_samples:
+            samples = results.pop("samples")
         dumped = json.dumps(results, indent=2, default=lambda o: str(o))
         print(dumped)
@@ -104,19 +107,20 @@ def main():
             with open(args.output_path, "w") as f:
                 f.write(dumped)

-            for task_name, config in results["configs"].items():
-                output_name = "{}_{}".format(
-                    re.sub("/", "__", args.model_args), task_name
-                )
-                if os.path.isdir(args.output_path):
-                    filename = f"./{args.output_path}/{output_name}.jsonl"
-                elif os.path.isfile(args.output_path):
-                    filename = (
-                        f"./{os.path.dirname(args.output_path)}/{output_name}.jsonl"
-                    )
-                with jsonlines.open(filename, "w") as f:
-                    f.write_all(samples[task_name])
+            if args.log_samples:
+                for task_name, config in results["configs"].items():
+                    output_name = "{}_{}".format(
+                        re.sub("/", "__", args.model_args), task_name
+                    )
+                    if os.path.isdir(args.output_path):
+                        filename = f"./{args.output_path}/{output_name}.jsonl"
+                    elif os.path.isfile(args.output_path):
+                        filename = (
+                            f"./{os.path.dirname(args.output_path)}/{output_name}.jsonl"
+                        )
+                    with jsonlines.open(filename, "w") as f:
+                        f.write_all(samples[task_name])

         print(
             f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
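In `main()` above, per-task sample files are named by replacing `/` in `--model_args` with `__` and appending the task name before the `.jsonl` suffix. A small illustration with assumed placeholder values:

import re

model_args = "pretrained=EleutherAI/pythia-70m"  # assumed example --model_args value
task_name = "lambada_openai"                     # assumed example task
output_name = "{}_{}".format(re.sub("/", "__", model_args), task_name)
print(f"{output_name}.jsonl")
# -> pretrained=EleutherAI__pythia-70m_lambada_openai.jsonl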