Commit a007bacb authored by Zhiwei Zhuang

fix two bugs when run with qasper_bool and toxigen

parent ef332026
@@ -2,11 +2,10 @@ import os
 import re
 import json
 import fnmatch
-import jsonlines
 import argparse
 import logging
 from pathlib import Path
+import numpy as np
 from lm_eval import evaluator, utils
 from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger, SPACING
@@ -15,6 +14,14 @@ from lm_eval.tasks import include_path
 from typing import Union
 
 
+def _handle_non_serializable(o):
+    if isinstance(o, np.int64):
+        return int(o)
+    elif isinstance(o, set):
+        return list(o)
+    raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")
+
+
 def parse_eval_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
@@ -103,6 +110,12 @@ def parse_eval_args() -> argparse.Namespace:
         default="INFO",
         help="Log error when tasks are not registered.",
     )
+    parser.add_argument(
+        "--huggingface_token",
+        type=str,
+        default=None,
+        help="HuggingFace token for downloading gated datasets such as toxigen; create one at https://huggingface.co/settings/tokens",
+    )
     return parser.parse_args()
@@ -119,7 +132,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
+    if args.huggingface_token:
+        from huggingface_hub import login
+        login(token=args.huggingface_token)
 
     if args.include_path is not None:
         eval_logger.info(f"Including path: {args.include_path}")
         include_path(args.include_path)
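Because the login runs before `include_path` and task construction, any gated dataset fetched during task setup (toxigen being the motivating case here) authenticates with the supplied token instead of requiring an interactive `huggingface-cli login`. A minimal sketch of the same flow outside the CLI; the token string and dataset id are placeholders, not values from this commit:

from huggingface_hub import login
import datasets

# Placeholder token; in the harness it arrives via --huggingface_token.
login(token="hf_xxxxxxxxxxxxxxxx")

# After login, gated datasets resolve with the stored credential.
ds = datasets.load_dataset("some-org/some-gated-dataset")  # hypothetical dataset id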
@@ -195,7 +211,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
-        dumped = json.dumps(results, indent=2, default=lambda o: str(o))
+        dumped = json.dumps(results, indent=2, default=_handle_non_serializable)
         if args.show_config:
             print(dumped)
@@ -210,9 +226,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
                     re.sub("/|=", "__", args.model_args), task_name
                 )
                 filename = path.joinpath(f"{output_name}.jsonl")
-
-                with jsonlines.open(filename, "w") as f:
-                    f.write_all(samples[task_name])
+                samples_dumped = json.dumps(
+                    samples[task_name], indent=2, default=_handle_non_serializable
+                )
+                filename.open("w").write(samples_dumped)
 
                 print(
                     f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
...
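One side effect worth flagging: the per-task samples file keeps its `.jsonl` suffix, but it now holds a single indented JSON array rather than one record per line, so it should be read back with a whole-file parse instead of a line-delimited reader. A sketch, with a placeholder path:

import json
from pathlib import Path

# Placeholder path; the harness derives the real name from model_args and task.
samples_file = Path("output/model__args_task.jsonl")

# The file is now one pretty-printed JSON array, not line-delimited JSON.
samples = json.loads(samples_file.read_text())
print(f"loaded {len(samples)} logged samples")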
@@ -15,7 +15,8 @@ from lm_eval.api.registry import (
 import logging
 
-eval_logger = logging.getLogger('lm-eval')
+eval_logger = logging.getLogger("lm-eval")
+
 
 def register_configurable_task(config: Dict[str, str]) -> int:
     SubClass = type(
...