Commit b7cfed19 authored by FarzanehNakhaee's avatar FarzanehNakhaee
Browse files

add example logger

parent db42dd03
import random
import itertools
import json
import collections
import logging
import sys
import torch
......@@ -22,6 +25,10 @@ from lm_eval.utils import (
from lm_eval.logger import eval_logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
@positional_deprecated
def simple_evaluate(
......@@ -222,6 +229,7 @@ def evaluate(
enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
)
)
example_logger = logging.getLogger("examples")
for doc_id, doc in doc_iterator:
# subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
......@@ -229,6 +237,10 @@ def evaluate(
metrics = task.process_results(
doc, [req.filtered_resps[key] for req in requests]
)
target = task.doc_to_target(doc)
example = {"doc_id": doc_id, "doc": doc['text'], "target": target}
example.update(metrics)
example_logger.info(json.dumps(example))
for metric, value in metrics.items():
vals[(task_name, key, metric)].append(value)
......
......@@ -2,13 +2,14 @@ import os
import json
import fnmatch
import argparse
import logging
from lm_eval import evaluator, utils
from lm_eval.tasks import ALL_TASKS
from lm_eval.logger import eval_logger
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger("main")
class MultiChoice:
def __init__(self, choices):
......@@ -56,9 +57,20 @@ def pattern_match(patterns, source_list):
task_names.add(matching)
return sorted(list(task_names))
def setup_example_logger(output_path, separator):
"""Sets up a logger that will save each example and prediction."""
example_logger = logging.getLogger("examples")
filename = f"./outputs/examples{separator}{output_path}.jsonl"
formatter = logging.Formatter("%(message)s")
handler = logging.FileHandler(filename)
handler.setFormatter(formatter)
example_logger.addHandler(handler)
example_logger.setLevel(logging.INFO)
def main():
args = parse_args()
os.makedirs("./outputs", exist_ok=True)
args = parse_args()
if args.limit:
eval_logger.warning(
......@@ -66,6 +78,10 @@ def main():
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
path_separator = "."
output_path = args.output_path if args.output_path is not None else ""
setup_example_logger(output_path, path_separator)
if args.tasks is not None:
if os.path.isdir(args.tasks):
import glob
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment