Unverified commit 3fc5bedc authored by Lintang Sutawika, committed by GitHub

Merge pull request #563 from farzanehnakhaee70/big-refactor

Logging Samples
parents 5fa87056 5b6847be
 import random
 import itertools
+import json
 import collections
+import logging
+import sys
 import torch
@@ -22,6 +25,10 @@ from lm_eval.utils import (
 from lm_eval.logger import eval_logger

+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(logging.StreamHandler(sys.stdout))
+

 @positional_deprecated
 def simple_evaluate(
@@ -142,7 +149,7 @@ def evaluate(
     results = collections.defaultdict(dict)
     versions = collections.defaultdict(dict)
     configs = collections.defaultdict(dict)
+    samples = collections.defaultdict(list)
     requests = collections.defaultdict(list)
     # docs = {}
@@ -225,6 +232,7 @@ def evaluate(
                 enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
             )
         )
+
         for doc_id, doc in doc_iterator:
             # subset instances to only this document id ; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
@@ -232,6 +240,16 @@ def evaluate(
             metrics = task.process_results(
                 doc, [req.filtered_resps[key] for req in requests]
             )
+            target = task.doc_to_target(doc)
+            example = {
+                "doc_id": doc_id,
+                "doc": doc,
+                "target": target,
+                "resps": [req.resps for req in requests],
+                "filtered_resps": [req.filtered_resps[key] for req in requests],
+            }
+            example.update(metrics)
+            samples[task_name].append(example)
             for metric, value in metrics.items():
                 vals[(task_name, key, metric)].append(value)
@@ -296,6 +314,7 @@ def evaluate(
             "results": dict(results),
             "configs": dict(configs),
             "versions": dict(versions),
+            "samples": samples,
         }
     else:
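For reference, one entry appended to samples[task_name] by the hunk above (apparently in lm_eval/evaluator.py) has the shape sketched below. This is a minimal illustration, assuming a loglikelihood-style task; every field value here is invented, only the keys come from the diff.

# Hypothetical shape of a single logged sample; values are invented.
example = {
    "doc_id": 0,                        # index of the document within the task split
    "doc": {"text": "..."},             # the raw task document
    "target": " Paris",                 # task.doc_to_target(doc)
    "resps": [[(-1.92, True)]],         # raw model responses, one list per request
    "filtered_resps": [(-1.92, True)],  # responses after the filter pipeline `key`
    "acc": 1.0,                         # per-document metrics merged in via example.update(metrics)
}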
@@ -16,6 +16,6 @@ metric_list:
   - metric: perplexity
     aggregation: perplexity
     higher_is_better: false
-  - metric: accuracy
+  - metric: acc
     aggregation: mean
     higher_is_better: true
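The corrected metric_list block parses as a plain list of mappings; a quick sanity check with PyYAML (the fragment is reconstructed from the hunk above, and PyYAML itself is an assumption, not a dependency this diff adds):

# Sanity-check the corrected metric_list fragment.
import yaml

fragment = """
metric_list:
  - metric: perplexity
    aggregation: perplexity
    higher_is_better: false
  - metric: acc
    aggregation: mean
    higher_is_better: true
"""
metrics = yaml.safe_load(fragment)["metric_list"]
print([m["metric"] for m in metrics])  # ['perplexity', 'acc']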
 import os
+import re
 import json
 import fnmatch
+import jsonlines
 import argparse
+import logging

 from lm_eval import evaluator, utils
 from lm_eval.api.registry import ALL_TASKS
@@ -98,13 +101,31 @@ def main():
         check_integrity=args.check_integrity,
     )

     if results is not None:
+        samples = results.pop("samples")
         dumped = json.dumps(results, indent=2)
         print(dumped)

         if args.output_path:
+            os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
             with open(args.output_path, "w") as f:
                 f.write(dumped)
+
+            for task_name, config in results["configs"].items():
+                output_name = "{}_{}".format(
+                    re.sub("/", "__", args.model_args), task_name
+                )
+                if os.path.isdir(args.output_path):
+                    filename = f"./{args.output_path}/{output_name}.jsonl"
+                elif os.path.isfile(args.output_path):
+                    filename = (
+                        f"./{os.path.dirname(args.output_path)}/{output_name}.jsonl"
+                    )
+
+                with jsonlines.open(filename, "w") as f:
+                    f.write_all(samples[task_name])

     print(
         f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
         f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
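Per this hunk (apparently main.py), when --output_path points at a file the per-task JSONL files land next to it in the parent directory, and when it points at a directory they land inside it. The files can be read back with the same jsonlines package; a minimal sketch, where the concrete filename is invented but follows main.py's output_name pattern of model_args with "/" replaced by "__", joined to the task name:

# Hypothetical reader for a per-task samples file produced by main.py above.
import jsonlines

with jsonlines.open("pretrained=EleutherAI__pythia-70m_lambada.jsonl") as reader:
    for sample in reader:
        print(sample["doc_id"], sample["target"], sample["filtered_resps"])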