Commit 7a058d69 authored by Leo Gao

Improve interface and upgrade results dict format

parent b569f483
@@ -2,9 +2,43 @@ import collections
 import itertools
 import random
 import lm_eval.metrics
+import lm_eval.models
+import lm_eval.tasks
+import lm_eval.base
+import numpy as np
+
+
+def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None, no_cache=False, limit=None, bootstrap_iters=100000):
+    random.seed(1234)
+    np.random.seed(1234)
+
+    lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
+        'batch_size': batch_size, 'device': device
+    })
+
+    if not no_cache:
+        lm = lm_eval.base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
+
+    task_dict = lm_eval.tasks.get_task_dict(task_names)
+
+    results = evaluate(lm, task_dict, False, num_fewshot, limit)
+
+    # add info about the model and few shot config
+    results["config"] = {
+        "model": model,
+        "model_args": model_args,
+        "num_fewshot": num_fewshot,
+        "batch_size": batch_size,
+        "device": device,
+        "no_cache": no_cache,
+        "limit": limit,
+        "bootstrap_iters": bootstrap_iters
+    }
+
+    return results
+
+
 def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
+    assert not provide_description  # not implemented. todo: implement proper description-providing system
+
     # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
     task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
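The new `simple_evaluate` entry point above wraps model construction, optional caching, task lookup, and evaluation in a single call, and records the run configuration in the returned dict. A minimal usage sketch (assuming this module imports as `lm_eval.evaluator`; the model name, `model_args` string, and task name are placeholders, not values from this commit):

```python
# Hedged sketch of the new one-call interface; assumes the file is importable
# as lm_eval.evaluator. Model name, model_args, and task names are placeholders.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="gpt2",                  # resolved via lm_eval.models.get_model
    model_args="pretrained=gpt2",  # hypothetical comma-separated key=value string
    task_names=["lambada"],        # hypothetical task name
    num_fewshot=0,
    no_cache=True,                 # skip the CachingLM wrapper
)

print(results["config"])   # run settings are now recorded in the output
print(results["results"])  # per-task metrics, as before
```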
@@ -100,6 +134,38 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
             results[task_name][metric + "_stderr"] = stderr(items)

     return {
-        "results": results,
-        "versions": versions
+        "results": dict(results),
+        "versions": dict(versions)
     }
+
+
+def make_table(result_dict):
+    from pytablewriter import MarkdownTableWriter, LatexTableWriter
+
+    md_writer = MarkdownTableWriter()
+    latex_writer = LatexTableWriter()
+    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
+    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
+
+    values = []
+
+    for k, dic in result_dict["results"].items():
+        version = result_dict["versions"][k]
+        for m, v in dic.items():
+            if m.endswith("_stderr"): continue
+
+            if m + "_stderr" in dic:
+                se = dic[m + "_stderr"]
+                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
+            else:
+                values.append([k, version, m, '%.4f' % v, '', ''])
+            k = ""
+            version = ""
+
+    md_writer.value_matrix = values
+    latex_writer.value_matrix = values
+
+    # todo: make latex table look good
+    # print(latex_writer.dumps())
+
+    return md_writer.dumps()
\ No newline at end of file
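The new `make_table` consumes the same results dict and renders one row per metric, blanking the task and version cells after the first row of each task. A sketch with a hand-built stand-in for evaluator output; the task name, metrics, and values are invented, and the rendered table shown in the comment is only approximate:

```python
# Hand-built stand-in for an evaluator results dict, only to show the table
# shape; real task names, metrics, and values come from actual runs.
fake = {
    "results": {
        "lambada": {"ppl": 4.1234, "ppl_stderr": 0.0567, "acc": 0.6543},
    },
    "versions": {"lambada": 0},
}

print(make_table(fake))
# Roughly:
# | Task    | Version | Metric | Value  |   | Stderr |
# |---------|---------|--------|--------|---|--------|
# | lambada | 0       | ppl    | 4.1234 | ± | 0.0567 |
# |         |         | acc    | 0.6543 |   |        |
```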
@@ -17,7 +17,6 @@ def parse_args():
     parser.add_argument('--num_fewshot', type=int, default=0)
     parser.add_argument('--batch_size', type=int, default=None)
     parser.add_argument('--device', type=str, default=None)
-    parser.add_argument('--seed', type=int, default=1234)
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
     parser.add_argument('--no_cache', action="store_true")
@@ -26,63 +25,29 @@ def parse_args():
 def main():
     args = parse_args()
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-
-    lm = models.get_model(args.model).create_from_arg_string(args.model_args, {
-        'batch_size': args.batch_size, 'device': args.device
-    })
+    assert not args.provide_description  # not implemented

     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

-    if not args.no_cache:
-        lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
-
     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
     else:
         task_names = args.tasks.split(",")

-    task_dict = tasks.get_task_dict(task_names)
-
-    results = evaluator.evaluate(lm, task_dict, args.provide_description, args.num_fewshot, args.limit)
+    results = evaluator.simple_evaluate(args.model, args.model_args, task_names, args.num_fewshot, args.batch_size, args.device, args.no_cache, args.limit)

     dumped = json.dumps(results, indent=2)
     print(dumped)

     if args.output_path:
         with open(args.output_path, "w") as f:
             f.write(dumped)

-    # MAKE TABLE
-    from pytablewriter import MarkdownTableWriter, LatexTableWriter
-
-    md_writer = MarkdownTableWriter()
-    latex_writer = LatexTableWriter()
-    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
-    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
-
-    values = []
-
-    for k, dic in results["results"].items():
-        version = results["versions"][k]
-        for m, v in dic.items():
-            if m.endswith("_stderr"): continue
-
-            if m + "_stderr" in dic:
-                se = dic[m + "_stderr"]
-                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
-            else:
-                values.append([k, version, m, '%.4f' % v, '', ''])
-            k = ""
-            version = ""
-
-    md_writer.value_matrix = values
-    latex_writer.value_matrix = values
-
-    # todo: make latex table look good
-    # print(latex_writer.dumps())
-
     print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
-    print(md_writer.dumps())
+    print(evaluator.make_table(results))


 if __name__ == "__main__":
     main()
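With the table code moved into `evaluator.make_table` and seeding moved into `simple_evaluate`, the JSON that the script prints and optionally writes to --output_path now carries the full run configuration alongside the metrics. An illustrative shape (every value here is a placeholder; the "config" keys mirror the dict built in `simple_evaluate` above):

```python
# Illustrative shape of the upgraded results dict that the script dumps as
# JSON; all values are placeholders, but the keys follow the diff above.
example = {
    "results": {
        "lambada": {"ppl": 4.1234, "ppl_stderr": 0.0567},  # per-task metrics
    },
    "versions": {"lambada": 0},        # task version numbers
    "config": {                        # added by simple_evaluate
        "model": "gpt2",
        "model_args": "pretrained=gpt2",
        "num_fewshot": 0,
        "batch_size": None,
        "device": None,
        "no_cache": False,
        "limit": None,
        "bootstrap_iters": 100000,
    },
}
```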