Commit a007bacb authored by Zhiwei Zhuang's avatar Zhiwei Zhuang
Browse files

fix two bugs when ran with qasper_bool and toxigen

parent ef332026
......@@ -2,11 +2,10 @@ import os
import re
import json
import fnmatch
import jsonlines
import argparse
import logging
from pathlib import Path
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING
......@@ -15,6 +14,14 @@ from lm_eval.tasks import include_path
from typing import Union
def _handle_non_serializable(o):
if isinstance(o, np.int64):
return int(o)
elif isinstance(o, set):
return list(o)
raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")
def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
......@@ -103,6 +110,12 @@ def parse_eval_args() -> argparse.Namespace:
default="INFO",
help="Log error when tasks are not registered.",
)
parser.add_argument(
"--huggingface_token",
type=str,
default=None,
help="huggingface token for downloading some authorization datasets, like toxigen, https://huggingface.co/settings/tokens",
)
return parser.parse_args()
......@@ -119,7 +132,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.huggingface_token:
from huggingface_hub import login
login(token=args.huggingface_token)
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
include_path(args.include_path)
......@@ -195,7 +211,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None:
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=lambda o: str(o))
dumped = json.dumps(results, indent=2, default=_handle_non_serializable)
if args.show_config:
print(dumped)
......@@ -210,9 +226,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
re.sub("/|=", "__", args.model_args), task_name
)
filename = path.joinpath(f"{output_name}.jsonl")
with jsonlines.open(filename, "w") as f:
f.write_all(samples[task_name])
samples_dumped = json.dumps(
samples[task_name], indent=2, default=_handle_non_serializable
)
filename.open("w").write(samples_dumped)
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
......
......@@ -15,7 +15,8 @@ from lm_eval.api.registry import (
import logging
eval_logger = logging.getLogger('lm-eval')
eval_logger = logging.getLogger("lm-eval")
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment