Commit 6c80d52a authored by nikuya3's avatar nikuya3
Browse files

Merge branch 'hellaswag' of github.com:nikuya3/lm-evaluation-harness into hellaswag

Conflicts:
	lm_eval/tasks/hellaswag/hellaswag.yaml
parents 5be6a53d cc89d4f9
......@@ -41,7 +41,6 @@ def parse_args():
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
......@@ -78,12 +77,6 @@ def main():
eval_logger.info(f"Selected Tasks: {task_names}")
# TODO: description_dict?
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
......@@ -94,7 +87,6 @@ def main():
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
# description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
......
......@@ -13,12 +13,10 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--output_base_path", required=True)
parser.add_argument("--tasks", default="all_tasks")
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--sets", type=str, default="val") # example: val,test
parser.add_argument("--num_fewshot", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_examples", type=int, default=1)
parser.add_argument("--description_dict_path", default=None)
return parser.parse_args()
......@@ -32,11 +30,6 @@ def main():
task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names)
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items():
rnd = random.Random()
......@@ -55,12 +48,6 @@ def main():
docs = join_iters(iters)
# description = (
# description_dict[task_name]
# if description_dict and task_name in description_dict
# else ""
# )
with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in (
zip(range(args.num_examples), docs)
......@@ -72,7 +59,6 @@ def main():
doc=doc,
num_fewshot=args.num_fewshot,
rnd=rnd,
# description=description,
)
f.write(ctx + "\n")
......
......@@ -28,6 +28,7 @@ setuptools.setup(
python_requires=">=3.9",
install_requires=[
"accelerate>=0.18.0",
"evaluate",
"datasets>=2.0.0",
"evaluate>=0.4.0",
"jsonlines",
......
......@@ -3,7 +3,7 @@ import lm_eval.tasks
import lm_eval.models
def test_description_dict():
def test_description():
seed = 42
num_examples = 1
task_names = ["hellaswag", "winogrande"]
......@@ -37,6 +37,5 @@ def test_description_dict():
doc=doc,
num_fewshot=1,
rnd=rnd,
description=description,
)
assert description in ctx
......@@ -55,7 +55,6 @@ def test_evaluator(taskname, task_class):
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None,
)
e2 = evaluator.evaluate(
lm=lm,
......@@ -63,7 +62,6 @@ def test_evaluator(taskname, task_class):
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None,
)
# check that caching is working
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment