Unverified Commit aff47420 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #934 from AndyWolfZwei/andy/big-refactor-bugfixed

[Refactor]fix two bugs when ran with qasper_bool and toxigen
parents 06faed0c 4ac7f064
......@@ -2,11 +2,10 @@ import os
import re
import json
import fnmatch
import jsonlines
import argparse
import logging
from pathlib import Path
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING
......@@ -15,6 +14,15 @@ from lm_eval.tasks import include_path
from typing import Union
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
......@@ -119,7 +127,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
include_path(args.include_path)
......@@ -195,7 +202,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None:
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=lambda o: str(o))
dumped = json.dumps(results, indent=2, default=_handle_non_serializable)
if args.show_config:
print(dumped)
......@@ -210,9 +217,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
re.sub("/|=", "__", args.model_args), task_name
)
filename = path.joinpath(f"{output_name}.jsonl")
with jsonlines.open(filename, "w") as f:
f.write_all(samples[task_name])
samples_dumped = json.dumps(
samples[task_name], indent=2, default=_handle_non_serializable
)
filename.open("w").write(samples_dumped)
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment