Commit d6b14050 authored by Baber

use dataclass; don't pass config to `simple_evaluate`

parent 9c94fb2e
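
In effect, the CLI call site stops handing `simple_evaluate` the whole config object and instead unpacks the config's fields into plain keyword arguments, so the evaluator API no longer depends on the config type. A minimal before/after sketch (argument names are taken from the diff below; the surrounding `cli_evaluate` body is abbreviated):

    # before: the EvaluationConfig object crossed the API boundary
    results = evaluator.simple_evaluate(
        config=config,
        evaluation_tracker=evaluation_tracker,
        task_manager=task_manager,
    )

    # after: only plain values are passed; the config stays on the CLI side
    results = evaluator.simple_evaluate(
        model=config.model,
        model_args=config.model_args,
        tasks=config.tasks,
        # ...remaining fields unpacked the same way (full list in the diff below)
        evaluation_tracker=evaluation_tracker,
        task_manager=task_manager,
    )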
@@ -491,9 +491,40 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     print(f"CONFIG_AFTER: {config}")
     results = evaluator.simple_evaluate(
-        config=config,
+        model=config.model,
+        model_args=config.model_args,
+        tasks=config.tasks,
+        num_fewshot=config.num_fewshot,
+        batch_size=config.batch_size,
+        max_batch_size=config.max_batch_size,
+        device=config.device,
+        use_cache=config.use_cache,
+        cache_requests=config.request_caching_args.get("cache_requests", False),
+        rewrite_requests_cache=config.request_caching_args.get(
+            "rewrite_requests_cache", False
+        ),
+        delete_requests_cache=config.request_caching_args.get(
+            "delete_requests_cache", False
+        ),
+        limit=config.limit,
+        samples=config.samples,
+        check_integrity=config.check_integrity,
+        write_out=config.write_out,
+        log_samples=config.log_samples,
         evaluation_tracker=evaluation_tracker,
+        system_instruction=config.system_instruction,
+        apply_chat_template=config.apply_chat_template,
+        fewshot_as_multiturn=config.fewshot_as_multiturn,
+        gen_kwargs=config.gen_kwargs,
         task_manager=task_manager,
+        verbosity=config.verbosity,
+        predict_only=config.predict_only,
+        random_seed=config.seed[0] if config.seed else None,
+        numpy_random_seed=config.seed[1] if config.seed else None,
+        torch_random_seed=config.seed[2] if config.seed else None,
+        fewshot_random_seed=config.seed[3] if config.seed else None,
+        confirm_run_unsafe_code=config.confirm_run_unsafe_code,
+        metadata=config.metadata,
     )
     if results is not None:
...
 import argparse
 import os
 from argparse import Namespace
+from dataclasses import dataclass
 from typing import Any, Dict, Optional, Union
 import yaml
-from pydantic import BaseModel
 from lm_eval.utils import simple_parse_args_string
@@ -18,7 +18,8 @@ DICT_KEYS = [
 ]
-class EvaluationConfig(BaseModel):
+@dataclass
+class EvaluationConfig:
     """
     Simple config container for language-model evaluation.
     No content validation here—just holds whatever comes from YAML or CLI.
@@ -58,7 +59,9 @@ class EvaluationConfig(BaseModel):
     request_caching_args: Optional[dict] = None
     @staticmethod
-    def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
+    def parse_namespace(
+        namespace: argparse.Namespace,
+    ) -> tuple[Dict[str, Any], list[Dict[str, Any]]]:
         """
         Convert an argparse Namespace object to a dictionary.
@@ -159,7 +162,8 @@ class EvaluationConfig(BaseModel):
             args_dict, config_data, explicit_args
         )
-        # 4. Instantiate the Pydantic model (no further validation here)
+        # 4. Instantiate the config (no further validation here)
+        dict_config.pop("_explicit_args", None)
         return cls(**dict_config)
...
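
For orientation, the container is now a plain dataclass rather than a Pydantic BaseModel, so constructing it performs no validation or type coercion. A rough sketch of its shape, using only fields that appear elsewhere in this commit; the defaults and the trailing fields are assumptions, not copied from the file:

    from dataclasses import dataclass
    from typing import Optional, Union


    @dataclass
    class EvaluationConfig:
        # holds whatever comes from YAML or the CLI; no content validation
        model: Optional[str] = None
        model_args: Optional[Union[str, dict]] = None
        tasks: Optional[list] = None
        num_fewshot: Optional[int] = None
        batch_size: Optional[Union[int, str]] = None
        device: Optional[str] = None
        seed: Optional[list] = None  # [python, numpy, torch, fewshot] seeds
        request_caching_args: Optional[dict] = None
        metadata: Optional[dict] = None
        # ...remaining CLI options (limit, samples, gen_kwargs, verbosity, ...) follow the same pattern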
@@ -4,7 +4,7 @@ import logging
 import random
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 import numpy as np
 import torch
@@ -13,7 +13,6 @@ import lm_eval.api.metrics
 import lm_eval.api.registry
 import lm_eval.api.task
 import lm_eval.models
-from lm_eval.api.eval_config import EvaluationConfig
 from lm_eval.caching.cache import delete_cache
 from lm_eval.evaluator_utils import (
     consolidate_group_results,
@@ -46,11 +45,37 @@ eval_logger = logging.getLogger(__name__)
 @positional_deprecated
 def simple_evaluate(
-    config: "EvaluationConfig",
-    # TODO: bootstrap_iters is not passed from cli_evaluate
+    model,
+    model_args: Optional[Union[str, dict]] = None,
+    tasks: Optional[List[Union[str, dict, object]]] = None,
+    num_fewshot: Optional[int] = None,
+    batch_size: Optional[Union[int, str]] = None,
+    max_batch_size: Optional[int] = None,
+    device: Optional[str] = None,
+    use_cache: Optional[str] = None,
+    cache_requests: bool = False,
+    rewrite_requests_cache: bool = False,
+    delete_requests_cache: bool = False,
+    limit: Optional[Union[int, float]] = None,
+    samples: Optional[dict] = None,
     bootstrap_iters: int = 100000,
+    check_integrity: bool = False,
+    write_out: bool = False,
+    log_samples: bool = True,
     evaluation_tracker: Optional[EvaluationTracker] = None,
+    system_instruction: Optional[str] = None,
+    apply_chat_template: Union[bool, str] = False,
+    fewshot_as_multiturn: bool = False,
+    gen_kwargs: Union[str, dict, None] = None,
     task_manager: Optional[TaskManager] = None,
+    verbosity=None,
+    predict_only: bool = False,
+    random_seed: int = 0,
+    numpy_random_seed: int = 1234,
+    torch_random_seed: int = 1234,
+    fewshot_random_seed: int = 1234,
+    confirm_run_unsafe_code: bool = False,
+    metadata: Optional[dict] = None,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -119,108 +144,110 @@ def simple_evaluate(
     :return
         Dictionary of results
     """
-    if config.verbosity is not None:
-        setup_logging(verbosity=config.verbosity)
+    if verbosity is not None:
+        setup_logging(verbosity=verbosity)
     start_date = time.time()
-    if config.limit is not None and config.samples is not None:
+    if limit is not None and samples is not None:
         raise ValueError(
             "Either 'limit' or 'samples' must be None, but both are not None."
         )
-    if isinstance(config.model_args, str) and (
-        "instruct" in config.model_args and not config.apply_chat_template
-    ):
+    if (
+        (isinstance(model_args, str) and "inst" in model_args.lower())
+        or (
+            isinstance(model_args, dict)
+            and any("inst" in str(v).lower() for v in model_args.values())
+        )
+    ) and not apply_chat_template:
         eval_logger.warning(
-            "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
+            "Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
         )
-    if config.request_caching_args.get("delete_requests_cache", False):
+    if delete_requests_cache:
         eval_logger.info("Deleting requests cache...")
         delete_cache()
     seed_message = []
-    if config.seed[0] is not None:
+    if random_seed is not None:
         # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
-        seed_message.append(f"Setting random seed to {config.seed[0]}")
-        random.seed(config.seed[0])
-    if config.seed[1] is not None:
-        seed_message.append(f"Setting numpy seed to {config.seed[1]}")
-        np.random.seed(config.seed[1])
-    if config.seed[2] is not None:
-        seed_message.append(f"Setting torch manual seed to {config.seed[2]}")
-        torch.manual_seed(config.seed[2])
-    if config.seed[3] is not None:
-        seed_message.append(f"Setting fewshot manual seed to {config.seed[3]}")
+        seed_message.append(f"Setting random seed to {random_seed}")
+        random.seed(random_seed)
+    if numpy_random_seed is not None:
+        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
+        np.random.seed(numpy_random_seed)
+    if torch_random_seed is not None:
+        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
+        torch.manual_seed(torch_random_seed)
+    if fewshot_random_seed is not None:
+        seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
     if seed_message:
         eval_logger.info(" | ".join(seed_message))
-    if config.tasks is None:
-        config.tasks = []
-    if len(config.tasks) == 0:
+    if tasks is None:
+        tasks = []
+    if len(tasks) == 0:
         raise ValueError(
             "No tasks specified, or no tasks found. Please verify the task names."
         )
-    if config.gen_kwargs is not None:
-        if isinstance(config.gen_kwargs, str):
-            config.gen_kwargs = simple_parse_args_string(config.gen_kwargs)
+    if gen_kwargs is not None:
+        if isinstance(gen_kwargs, str):
+            gen_kwargs = simple_parse_args_string(gen_kwargs)
         eval_logger.warning(
-            f"generation_kwargs: {config.gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
+            f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
             "Ensure 'do_sample=True' for non-greedy decoding!"
         )
-        if not config.gen_kwargs:
-            config.gen_kwargs = None
+        if not gen_kwargs:
+            gen_kwargs = None
-    if isinstance(config.model, str):
-        if config.model_args is None:
+    if isinstance(model, str):
+        if model_args is None:
             eval_logger.warning("model_args not specified. Using defaults.")
-            config.model_args = ""
+            model_args = ""
-        if isinstance(config.model_args, dict):
+        if isinstance(model_args, dict):
             eval_logger.info(
-                f"Initializing {config.model} model, with arguments: {config.model_args}"
+                f"Initializing {model} model, with arguments: {model_args}"
             )
-            lm = lm_eval.api.registry.get_model(config.model).create_from_arg_obj(
-                config.model_args,
+            lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
+                model_args,
                 {
-                    "batch_size": config.batch_size,
-                    "max_batch_size": config.max_batch_size,
-                    "device": config.device,
+                    "batch_size": batch_size,
+                    "max_batch_size": max_batch_size,
+                    "device": device,
                 },
             )
         else:
             eval_logger.info(
-                f"Initializing {config.model} model, with arguments: {simple_parse_args_string(config.model_args)}"
+                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
             )
-            lm = lm_eval.api.registry.get_model(config.model).create_from_arg_string(
-                config.model_args,
+            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
+                model_args,
                 {
-                    "batch_size": config.batch_size,
-                    "max_batch_size": config.max_batch_size,
-                    "device": config.device,
+                    "batch_size": batch_size,
+                    "max_batch_size": max_batch_size,
+                    "device": device,
                 },
             )
     else:
-        if not isinstance(config.model, lm_eval.api.model.LM):
+        if not isinstance(model, lm_eval.api.model.LM):
             raise TypeError(
-                f"The value of `model` passed to simple_evaluate() was of type {type(config.model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
+                f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
             )
         eval_logger.info("Using pre-initialized model")
-        lm = config.model
+        lm = model
-    if config.use_cache is not None:
-        eval_logger.info(
-            f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}"
-        )
+    if use_cache is not None:
+        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
         lm = lm_eval.api.model.CachingLM(
             lm,
-            config.use_cache
+            use_cache
             # each rank receives a different cache db.
             # necessary to avoid multiple writes to cache at once
             + "_rank"
@@ -229,10 +256,17 @@ def simple_evaluate(
         )
     if task_manager is None:
-        task_manager = TaskManager(metadata=config.metadata)
+        metadata = (
+            simple_parse_args_string(model_args)
+            if isinstance(model_args, str)
+            else model_args
+            if isinstance(model_args, dict)
+            else {}
+        ) | (metadata or {})
+        task_manager = TaskManager(metadata=metadata)
     task_dict = get_task_dict(
-        config.tasks,
+        tasks,
         task_manager,
     )
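
The new `TaskManager` setup merges metadata parsed out of `model_args` with any explicitly passed `metadata`, and the dict union lets the explicit metadata win on key collisions. An illustrative evaluation of that expression (the values are made up; `simple_parse_args_string` is assumed to parse "k=v,k=v" strings into a dict, as elsewhere in this diff):

    model_args = "pretrained=EleutherAI/pythia-160m,max_length=2048"
    metadata = {"max_length": 4096}
    # simple_parse_args_string(model_args) | (metadata or {})
    # -> {"pretrained": "EleutherAI/pythia-160m", "max_length": 4096}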
@@ -249,17 +283,15 @@ def simple_evaluate(
             else:
                 if task_obj.get_config("output_type") == "generate_until":
-                    if config.gen_kwargs is not None:
+                    if gen_kwargs is not None:
                         task_obj.set_config(
-                            key="generation_kwargs",
-                            value=config.gen_kwargs,
-                            update=True,
+                            key="generation_kwargs", value=gen_kwargs, update=True
                         )
                     eval_logger.info(
                         f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
                     )
-                if config.predict_only:
+                if predict_only:
                     eval_logger.info(
                         f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                     )
@@ -268,16 +300,16 @@ def simple_evaluate(
                 # override tasks' fewshot values to the provided num_fewshot arg value
                 # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
-                if config.num_fewshot is not None:
+                if num_fewshot is not None:
                     if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                         eval_logger.info(
                             f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                         )
                     else:
                         eval_logger.warning(
-                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
+                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                         )
-                        task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
+                        task_obj.set_config(key="num_fewshot", value=num_fewshot)
                 else:
                     # if num_fewshot not provided, and the task does not define a default one, default to 0
                     if (
@@ -285,7 +317,7 @@ def simple_evaluate(
                     ) is None:
                         task_obj.set_config(key="num_fewshot", value=0)
                 # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
-                task_obj.set_fewshot_seed(seed=config.seed[3])
+                task_obj.set_fewshot_seed(seed=fewshot_random_seed)
             adjusted_task_dict[task_name] = task_obj
@@ -293,55 +325,51 @@ def simple_evaluate(
     task_dict = _adjust_config(task_dict)
-    if config.check_integrity:
-        run_task_tests(task_list=config.tasks)
+    if check_integrity:
+        run_task_tests(task_list=tasks)
     if evaluation_tracker is not None:
         evaluation_tracker.general_config_tracker.log_experiment_args(
-            model_source=config.model,
-            model_args=config.model_args,
-            system_instruction=config.system_instruction,
-            chat_template=lm.chat_template(config.apply_chat_template)
-            if config.apply_chat_template
+            model_source=model,
+            model_args=model_args,
+            system_instruction=system_instruction,
+            chat_template=lm.chat_template(apply_chat_template)
+            if apply_chat_template
             else None,
-            fewshot_as_multiturn=config.fewshot_as_multiturn,
+            fewshot_as_multiturn=fewshot_as_multiturn,
         )
     results = evaluate(
         lm=lm,
         task_dict=task_dict,
-        limit=config.limit,
-        samples=config.samples,
-        cache_requests=config.cache_requests,
-        rewrite_requests_cache=config.request_caching_args.get(
-            "rewrite_requests_cache", False
-        ),
+        limit=limit,
+        samples=samples,
+        cache_requests=cache_requests,
+        rewrite_requests_cache=rewrite_requests_cache,
         bootstrap_iters=bootstrap_iters,
-        write_out=config.write_out,
-        log_samples=True if config.predict_only else config.log_samples,
-        system_instruction=config.system_instruction,
-        apply_chat_template=config.apply_chat_template,
-        fewshot_as_multiturn=config.fewshot_as_multiturn,
-        verbosity=config.verbosity,
-        confirm_run_unsafe_code=config.confirm_run_unsafe_code,
+        write_out=write_out,
+        log_samples=True if predict_only else log_samples,
+        system_instruction=system_instruction,
+        apply_chat_template=apply_chat_template,
+        fewshot_as_multiturn=fewshot_as_multiturn,
+        verbosity=verbosity,
+        confirm_run_unsafe_code=confirm_run_unsafe_code,
     )
-    if config.verbosity is not None:
-        setup_logging(verbosity=config.verbosity)
+    if verbosity is not None:
+        setup_logging(verbosity=verbosity)
     if lm.rank == 0:
-        if isinstance(config.model, str):
-            model_name = config.model
-        elif hasattr(config.model, "config") and hasattr(
-            config.model.config, "_name_or_path"
-        ):
-            model_name = config.model.config._name_or_path
+        if isinstance(model, str):
+            model_name = model
+        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
+            model_name = model.config._name_or_path
         else:
-            model_name = type(config.model).__name__
+            model_name = type(model).__name__
         # add info about the model and few shot config
         results["config"] = {
             "model": model_name,
-            "model_args": config.model_args,
+            "model_args": model_args,
         }
         # add more detailed model info if available
         if isinstance(lm, lm_eval.models.huggingface.HFLM):
@@ -349,19 +377,19 @@ def simple_evaluate(
         # add info about execution
         results["config"].update(
             {
-                "batch_size": config.batch_size,
+                "batch_size": batch_size,
                 "batch_sizes": (
                     list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                 ),
-                "device": config.device,
-                "use_cache": config.use_cache,
-                "limit": config.limit,
+                "device": device,
+                "use_cache": use_cache,
+                "limit": limit,
                 "bootstrap_iters": bootstrap_iters,
-                "gen_kwargs": config.gen_kwargs,
-                "random_seed": config.seed[0],
-                "numpy_seed": config.seed[1],
-                "torch_seed": config.seed[2],
-                "fewshot_seed": config.seed[3],
+                "gen_kwargs": gen_kwargs,
+                "random_seed": random_seed,
+                "numpy_seed": numpy_random_seed,
+                "torch_seed": torch_random_seed,
+                "fewshot_seed": fewshot_random_seed,
             }
         )
         results["git_hash"] = get_git_commit_hash()
@@ -459,7 +487,7 @@ def evaluate(
     for task_output in eval_tasks:
         task: Task = task_output.task
-        if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
+        if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
             incompatible_tasks.append(task_output.task_name)
         elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
             raise ValueError(
@@ -470,10 +498,6 @@ def evaluate(
             raise ValueError(
                 f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
             )
-        else:
-            raise ValueError(
-                f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
-            )
     # end validation check
     # Cache the limit arg.
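
Net effect of the two `evaluate()` hunks above: only the combination of a multimodal task with a text-only model is treated as incompatible, and running text-only tasks on a multimodal model no longer raises. A quick truth table for the new guard:

    # getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False)
    # multimodal task, text-only model  -> flagged as incompatible (raises later)
    # multimodal task, multimodal model -> allowed
    # text-only task,  multimodal model -> allowed (previously raised)
    # text-only task,  text-only model  -> allowed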
@@ -731,11 +755,11 @@ def evaluate(
         return None
-# def request_caching_arg_to_dict(cache_requests: str) -> dict:
-#     request_caching_args = {
-#         "cache_requests": cache_requests in {"true", "refresh"},
-#         "rewrite_requests_cache": cache_requests == "refresh",
-#         "delete_requests_cache": cache_requests == "delete",
-#     }
-#     return request_caching_args
+def request_caching_arg_to_dict(cache_requests: str) -> dict:
+    request_caching_args = {
+        "cache_requests": cache_requests in {"true", "refresh"},
+        "rewrite_requests_cache": cache_requests == "refresh",
+        "delete_requests_cache": cache_requests == "delete",
+    }
+    return request_caching_args
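
With the helper uncommented, the CLI can turn a `cache_requests` string back into the three booleans that `simple_evaluate` now takes individually. Illustrative usage (not part of this commit):

    caching_args = request_caching_arg_to_dict(cache_requests="refresh")
    # -> {"cache_requests": True, "rewrite_requests_cache": True, "delete_requests_cache": False}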