Unverified Commit 3fc5bedc authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #563 from farzanehnakhaee70/big-refactor

Logging Samples
parents 5fa87056 5b6847be
import random
import itertools
import json
import collections
import logging
import sys
import torch
......@@ -22,6 +25,10 @@ from lm_eval.utils import (
from lm_eval.logger import eval_logger

# Module-level logger for this file, sending INFO-and-above records to stdout.
# NOTE(review): this runs once at import time; an explicit StreamHandler is added
# here, so if the application also configures the root logger, records from this
# module may be emitted twice — confirm against the project's logging setup.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
@positional_deprecated
def simple_evaluate(
......@@ -142,7 +149,7 @@ def evaluate(
results = collections.defaultdict(dict)
versions = collections.defaultdict(dict)
configs = collections.defaultdict(dict)
samples = collections.defaultdict(list)
requests = collections.defaultdict(list)
# docs = {}
......@@ -225,6 +232,7 @@ def evaluate(
enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
)
)
for doc_id, doc in doc_iterator:
# subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
......@@ -232,6 +240,16 @@ def evaluate(
metrics = task.process_results(
doc, [req.filtered_resps[key] for req in requests]
)
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id,
"doc": doc,
"target": target,
"resps": [req.resps for req in requests],
"filtered_resps": [req.filtered_resps[key] for req in requests],
}
example.update(metrics)
samples[task_name].append(example)
for metric, value in metrics.items():
vals[(task_name, key, metric)].append(value)
......@@ -296,6 +314,7 @@ def evaluate(
"results": dict(results),
"configs": dict(configs),
"versions": dict(versions),
"samples": samples,
}
else:
......
......@@ -16,6 +16,6 @@ metric_list:
- metric: perplexity
aggregation: perplexity
higher_is_better: false
- metric: accuracy
- metric: acc
aggregation: mean
higher_is_better: true
import os
import re
import json
import fnmatch
import jsonlines
import argparse
import logging
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
......@@ -98,13 +101,31 @@ def main():
check_integrity=args.check_integrity,
)
if results is not None:
samples = results.pop("samples")
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
with open(args.output_path, "w") as f:
f.write(dumped)
for task_name, config in results["configs"].items():
output_name = "{}_{}".format(
re.sub("/", "__", args.model_args), task_name
)
if os.path.isdir(args.output_path):
filename = f"./{args.output_path}/{output_name}.jsonl"
elif os.path.isfile(args.output_path):
filename = (
f"./{os.path.dirname(args.output_path)}/{output_name}.jsonl"
)
with jsonlines.open(filename, "w") as f:
f.write_all(samples[task_name])
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment