Unverified Commit aff47420 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #934 from AndyWolfZwei/andy/big-refactor-bugfixed

[Refactor]fix two bugs when ran with qasper_bool and toxigen
parents 06faed0c 4ac7f064
...@@ -2,11 +2,10 @@ import os ...@@ -2,11 +2,10 @@ import os
import re import re
import json import json
import fnmatch import fnmatch
import jsonlines
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import numpy as np
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING from lm_eval.logger import eval_logger, SPACING
...@@ -15,6 +14,15 @@ from lm_eval.tasks import include_path ...@@ -15,6 +14,15 @@ from lm_eval.tasks import include_path
from typing import Union from typing import Union
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def parse_eval_args() -> argparse.Namespace: def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`") parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
...@@ -119,7 +127,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -119,7 +127,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
" --limit SHOULD ONLY BE USED FOR TESTING." " --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
) )
if args.include_path is not None: if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}") eval_logger.info(f"Including path: {args.include_path}")
include_path(args.include_path) include_path(args.include_path)
...@@ -195,7 +202,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -195,7 +202,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None: if results is not None:
if args.log_samples: if args.log_samples:
samples = results.pop("samples") samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=lambda o: str(o)) dumped = json.dumps(results, indent=2, default=_handle_non_serializable)
if args.show_config: if args.show_config:
print(dumped) print(dumped)
...@@ -210,9 +217,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -210,9 +217,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
re.sub("/|=", "__", args.model_args), task_name re.sub("/|=", "__", args.model_args), task_name
) )
filename = path.joinpath(f"{output_name}.jsonl") filename = path.joinpath(f"{output_name}.jsonl")
samples_dumped = json.dumps(
with jsonlines.open(filename, "w") as f: samples[task_name], indent=2, default=_handle_non_serializable
f.write_all(samples[task_name]) )
filename.open("w").write(samples_dumped)
print( print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment