Commit 7b2b2a23 authored by Leo Gao
Browse files

Make simple_evaluate take LM and Task objects directly too

parent aea963a1
......@@ -10,18 +10,19 @@ from lm_eval.utils import positional_deprecated
@positional_deprecated
def simple_evaluate(model, model_args, task_names,
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None):
"""Instantiate and evaluate a model on a list of tasks.
:param model: str
Name of model, see lm_eval.models.get_model
:param model_args: str
String arguments for each model class, see LM.create_from_arg_string
:param task_names: list[str]
List of task names
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]]
List of task names or Task objects
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int, optional
......@@ -42,16 +43,23 @@ def simple_evaluate(model, model_args, task_names,
random.seed(1234)
np.random.seed(1234)
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
assert tasks != [], "No tasks specified"
if isinstance(model, str):
if model_args is None: model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
if not no_cache:
lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
)
task_dict = lm_eval.tasks.get_task_dict(task_names)
task_dict = lm_eval.tasks.get_task_dict(tasks)
results = evaluate(
lm=lm,
......
from pprint import pprint
from typing import List, Union
import sacrebleu
import lm_eval.base
from . import superglue
from . import glue
......@@ -232,8 +234,20 @@ def get_task(task_name):
raise KeyError(f"Missing task {task_name}")
def get_task_dict(task_name_list):
def get_task_name_from_object(task_object):
    """Return the name to report results under for `task_object`.

    Lookup order:
      1. If `task_object` is (by identity) one of the values stored in
         TASK_REGISTRY, the corresponding registry key is returned.
      2. Otherwise, if it exposes an `EVAL_HARNESS_NAME` attribute, that
         value is returned.
      3. Otherwise its `__name__` is used when present (i.e. when a class was
         passed), falling back to the name of its type for instances.

    :param task_object:
        A task class or task instance.
    :return:
        The name used to identify the task when reporting.
    """
    for name, class_ in TASK_REGISTRY.items():
        if class_ is task_object:
            return name

    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    if hasattr(task_object, "EVAL_HARNESS_NAME"):
        return task_object.EVAL_HARNESS_NAME

    # Instances do not have `__name__` (only classes/functions do), so a bare
    # `task_object.__name__` raises AttributeError for Task instances.
    return getattr(task_object, "__name__", type(task_object).__name__)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
    """Build a mapping from task name to task object.

    :param task_name_list: list[Union[str, Task]]
        Entries that are strings are resolved with `get_task` and the result
        is called with no arguments; every other entry is used as-is and
        named via `get_task_name_from_object`.
    :return: dict
        Mapping of task name to task object.
    :raises ValueError:
        If a task passed as an object resolves to the same name as a task
        passed by name (one would silently overwrite the other).
    """
    tasks_by_name = {
        task_name: get_task(task_name)()
        for task_name in task_name_list
        if isinstance(task_name, str)
    }
    tasks_from_objects = {
        get_task_name_from_object(task_object): task_object
        for task_object in task_name_list
        if not isinstance(task_object, str)
    }

    clashes = sorted(set(tasks_by_name) & set(tasks_from_objects))
    if clashes:
        raise ValueError(f"Duplicate task names: {clashes}")

    # dicts do not support `+`; merge by unpacking instead.
    return {**tasks_by_name, **tasks_from_objects}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment