Commit b7cfed19 authored by FarzanehNakhaee

add example logger

parent db42dd03
 import random
 import itertools
+import json
 import collections
+import logging
+import sys

 import torch
...
@@ -22,6 +25,10 @@ from lm_eval.utils import (
 from lm_eval.logger import eval_logger

+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(logging.StreamHandler(sys.stdout))

 @positional_deprecated
 def simple_evaluate(
...
@@ -222,6 +229,7 @@ def evaluate(
                 enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
             )
         )
+        example_logger = logging.getLogger("examples")
         for doc_id, doc in doc_iterator:
            # subset instances to only this document id ; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
...
@@ -229,6 +237,10 @@ def evaluate(
             metrics = task.process_results(
                 doc, [req.filtered_resps[key] for req in requests]
             )
+            target = task.doc_to_target(doc)
+            example = {"doc_id": doc_id, "doc": doc['text'], "target": target}
+            example.update(metrics)
+            example_logger.info(json.dumps(example))
             for metric, value in metrics.items():
                 vals[(task_name, key, metric)].append(value)
...
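For context, the new per-document block in `evaluate` builds one flat JSON record (doc id, input text, gold target, plus whatever keys the task's `process_results` returns) and sends it through the `examples` logger. A minimal sketch of that record shape, assuming a doc that carries a `text` field and a task that reports an `acc` metric (both are assumptions; real keys depend on the task):

```python
import json
import logging

example_logger = logging.getLogger("examples")
example_logger.addHandler(logging.StreamHandler())
example_logger.setLevel(logging.INFO)

# Hypothetical stand-ins for one document from task.validation_docs().
doc_id = 0
doc = {"text": "The capital of France is", "label": "Paris"}
target = " Paris"        # what task.doc_to_target(doc) might return
metrics = {"acc": 1.0}   # assumed metric name; depends on process_results

example = {"doc_id": doc_id, "doc": doc["text"], "target": target}
example.update(metrics)
example_logger.info(json.dumps(example))  # emits one JSONL line per document
```

Because the record is flattened with `example.update(metrics)`, the metric names share one namespace with `doc_id`, `doc`, and `target`.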
@@ -2,13 +2,14 @@ import os
 import json
 import fnmatch
 import argparse
+import logging

 from lm_eval import evaluator, utils
 from lm_eval.tasks import ALL_TASKS
 from lm_eval.logger import eval_logger

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

+logger = logging.getLogger("main")

 class MultiChoice:
     def __init__(self, choices):
...
@@ -56,9 +57,20 @@ def pattern_match(patterns, source_list):
             task_names.add(matching)
     return sorted(list(task_names))

+def setup_example_logger(output_path, separator):
+    """Sets up a logger that will save each example and prediction."""
+    example_logger = logging.getLogger("examples")
+    filename = f"./outputs/examples{separator}{output_path}.jsonl"
+    formatter = logging.Formatter("%(message)s")
+    handler = logging.FileHandler(filename)
+    handler.setFormatter(formatter)
+    example_logger.addHandler(handler)
+    example_logger.setLevel(logging.INFO)

 def main():
-    args = parse_args()
+    os.makedirs("./outputs", exist_ok=True)
+    args = parse_args()

     if args.limit:
         eval_logger.warning(
...
@@ -66,6 +78,10 @@ def main():
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )

+    path_separator = "."
+    output_path = args.output_path if args.output_path is not None else ""
+    setup_example_logger(output_path, path_separator)

     if args.tasks is not None:
         if os.path.isdir(args.tasks):
             import glob
...
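On the script side, `setup_example_logger` attaches a `FileHandler` to the same `examples` logger, so every record logged during `evaluate` lands as one line of a JSONL file under `./outputs/`. A rough end-to-end sketch, with `"my_run"` standing in for a hypothetical `--output_path` value:

```python
import json
import logging
import os

def setup_example_logger(output_path, separator):
    """Sets up a logger that will save each example and prediction."""
    example_logger = logging.getLogger("examples")
    filename = f"./outputs/examples{separator}{output_path}.jsonl"
    formatter = logging.Formatter("%(message)s")  # raw JSON only, no timestamps or levels
    handler = logging.FileHandler(filename)
    handler.setFormatter(formatter)
    example_logger.addHandler(handler)
    example_logger.setLevel(logging.INFO)

os.makedirs("./outputs", exist_ok=True)  # mirrors the makedirs call added to main()
setup_example_logger("my_run", ".")      # -> ./outputs/examples.my_run.jsonl
logging.getLogger("examples").info(json.dumps({"doc_id": 0, "acc": 1.0}))
```

Note that when `--output_path` is not supplied, `output_path` defaults to the empty string, so the file name collapses to `./outputs/examples..jsonl`.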