"test/vscode:/vscode.git/clone" did not exist on "144bc70fcceede77fc2c2fbd286676b57f9a0c94"
Commit d6b14050 authored by Baber

use dataclass; don't pass config to `simple_evaluate`

parent 9c94fb2e
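At a glance, the change below removes the single `config` parameter from `simple_evaluate` and instead passes each field of the new `EvaluationConfig` dataclass explicitly at the call site. A minimal sketch of the resulting call pattern, with hypothetical field values (field names and the import path are taken from the diff):

from lm_eval import evaluator
from lm_eval.api.eval_config import EvaluationConfig

# Hypothetical values; the real CLI path builds the config from argparse/YAML.
config = EvaluationConfig(model="hf", model_args="pretrained=gpt2", tasks=["lambada_openai"])

results = evaluator.simple_evaluate(
    model=config.model,
    model_args=config.model_args,
    tasks=config.tasks,
)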
@@ -491,9 +491,40 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
print(f"CONFIG_AFTER: {config}")
results = evaluator.simple_evaluate(
config=config,
model=config.model,
model_args=config.model_args,
tasks=config.tasks,
num_fewshot=config.num_fewshot,
batch_size=config.batch_size,
max_batch_size=config.max_batch_size,
device=config.device,
use_cache=config.use_cache,
cache_requests=config.request_caching_args.get("cache_requests", False),
rewrite_requests_cache=config.request_caching_args.get(
"rewrite_requests_cache", False
),
delete_requests_cache=config.request_caching_args.get(
"delete_requests_cache", False
),
limit=config.limit,
samples=config.samples,
check_integrity=config.check_integrity,
write_out=config.write_out,
log_samples=config.log_samples,
evaluation_tracker=evaluation_tracker,
system_instruction=config.system_instruction,
apply_chat_template=config.apply_chat_template,
fewshot_as_multiturn=config.fewshot_as_multiturn,
gen_kwargs=config.gen_kwargs,
task_manager=task_manager,
verbosity=config.verbosity,
predict_only=config.predict_only,
random_seed=config.seed[0] if config.seed else None,
numpy_random_seed=config.seed[1] if config.seed else None,
torch_random_seed=config.seed[2] if config.seed else None,
fewshot_random_seed=config.seed[3] if config.seed else None,
confirm_run_unsafe_code=config.confirm_run_unsafe_code,
metadata=config.metadata,
)
if results is not None:
......
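A note on the seed plumbing in the hunk above: `config.seed` is treated as a four-element sequence mapped positionally onto the four seed arguments of `simple_evaluate`. A sketch of that convention (the example tuple mirrors the defaults in the `simple_evaluate` signature further down):

# Sketch only; positional convention is (python random, numpy, torch, fewshot).
seed = (0, 1234, 1234, 1234)
random_seed = seed[0] if seed else None
numpy_random_seed = seed[1] if seed else None
torch_random_seed = seed[2] if seed else None
fewshot_random_seed = seed[3] if seed else None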
import argparse
import os
from argparse import Namespace
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import yaml
from pydantic import BaseModel
from lm_eval.utils import simple_parse_args_string
@@ -18,7 +18,8 @@ DICT_KEYS = [
]
class EvaluationConfig(BaseModel):
@dataclass
class EvaluationConfig:
"""
Simple config container for language-model evaluation.
No content validation here—just holds whatever comes from YAML or CLI.
@@ -58,7 +59,9 @@ class EvaluationConfig(BaseModel):
request_caching_args: Optional[dict] = None
@staticmethod
def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
def parse_namespace(
namespace: argparse.Namespace,
) -> tuple[Dict[str, Any], list[Dict[str, Any]]]:
"""
Convert an argparse Namespace object to a dictionary.
@@ -159,7 +162,8 @@ class EvaluationConfig(BaseModel):
args_dict, config_data, explicit_args
)
# 4. Instantiate the Pydantic model (no further validation here)
# 4. Instantiate the config (no further validation here)
dict_config.pop("_explicit_args", None)
return cls(**dict_config)
......
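For context, the hunk above replaces the `pydantic.BaseModel` base class with a plain `@dataclass`, so constructing the config performs no validation or coercion. An abbreviated sketch under that assumption (only fields visible in this diff are shown):

from dataclasses import dataclass
from typing import Optional

@dataclass
class EvaluationConfig:
    # Abbreviated; the real class declares many more fields.
    model: Optional[str] = None
    model_args: Optional[str] = None
    request_caching_args: Optional[dict] = None

# Plain attribute assignment; no pydantic validation runs on construction.
cfg = EvaluationConfig(model="hf", request_caching_args={"cache_requests": True})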
@@ -4,7 +4,7 @@ import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
@@ -13,7 +13,6 @@ import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.api.eval_config import EvaluationConfig
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_group_results,
@@ -46,11 +45,37 @@ eval_logger = logging.getLogger(__name__)
@positional_deprecated
def simple_evaluate(
config: "EvaluationConfig",
# TODO: bootstrap_iters is not passed from cli_evaluate
model,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
use_cache: Optional[str] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
delete_requests_cache: bool = False,
limit: Optional[Union[int, float]] = None,
samples: Optional[dict] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
write_out: bool = False,
log_samples: bool = True,
evaluation_tracker: Optional[EvaluationTracker] = None,
system_instruction: Optional[str] = None,
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
gen_kwargs: Union[str, dict, None] = None,
task_manager: Optional[TaskManager] = None,
verbosity=None,
predict_only: bool = False,
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
metadata: Optional[dict] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -119,108 +144,110 @@ def simple_evaluate(
return
Dictionary of results
"""
if config.verbosity is not None:
setup_logging(verbosity=config.verbosity)
if verbosity is not None:
setup_logging(verbosity=verbosity)
start_date = time.time()
if config.limit is not None and config.samples is not None:
if limit is not None and samples is not None:
raise ValueError(
"Either 'limit' or 'samples' must be None, but both are not None."
)
if isinstance(config.model_args, str) and (
"instruct" in config.model_args and not config.apply_chat_template
):
if (
(isinstance(model_args, str) and "inst" in model_args.lower())
or (
isinstance(model_args, dict)
and any("inst" in str(v).lower() for v in model_args.values())
)
) and not apply_chat_template:
eval_logger.warning(
"Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
"Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
)
if config.request_caching_args.get("delete_requests_cache", False):
if delete_requests_cache:
eval_logger.info("Deleting requests cache...")
delete_cache()
seed_message = []
if config.seed[0] is not None:
if random_seed is not None:
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
seed_message.append(f"Setting random seed to {config.seed[0]}")
random.seed(config.seed[0])
seed_message.append(f"Setting random seed to {random_seed}")
random.seed(random_seed)
if config.seed[1] is not None:
seed_message.append(f"Setting numpy seed to {config.seed[1]}")
np.random.seed(config.seed[1])
if numpy_random_seed is not None:
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
np.random.seed(numpy_random_seed)
if config.seed[2] is not None:
seed_message.append(f"Setting torch manual seed to {config.seed[2]}")
torch.manual_seed(config.seed[2])
if torch_random_seed is not None:
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)
if config.seed[3] is not None:
seed_message.append(f"Setting fewshot manual seed to {config.seed[3]}")
if fewshot_random_seed is not None:
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
if seed_message:
eval_logger.info(" | ".join(seed_message))
if config.tasks is None:
config.tasks = []
if len(config.tasks) == 0:
if tasks is None:
tasks = []
if len(tasks) == 0:
raise ValueError(
"No tasks specified, or no tasks found. Please verify the task names."
)
if config.gen_kwargs is not None:
if isinstance(config.gen_kwargs, str):
config.gen_kwargs = simple_parse_args_string(config.gen_kwargs)
if gen_kwargs is not None:
if isinstance(gen_kwargs, str):
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
f"generation_kwargs: {config.gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
)
if not config.gen_kwargs:
config.gen_kwargs = None
if not gen_kwargs:
gen_kwargs = None
if isinstance(config.model, str):
if config.model_args is None:
if isinstance(model, str):
if model_args is None:
eval_logger.warning("model_args not specified. Using defaults.")
config.model_args = ""
model_args = ""
if isinstance(config.model_args, dict):
if isinstance(model_args, dict):
eval_logger.info(
f"Initializing {config.model} model, with arguments: {config.model_args}"
f"Initializing {model} model, with arguments: {model_args}"
)
lm = lm_eval.api.registry.get_model(config.model).create_from_arg_obj(
config.model_args,
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
{
"batch_size": config.batch_size,
"max_batch_size": config.max_batch_size,
"device": config.device,
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
else:
eval_logger.info(
f"Initializing {config.model} model, with arguments: {simple_parse_args_string(config.model_args)}"
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
)
lm = lm_eval.api.registry.get_model(config.model).create_from_arg_string(
config.model_args,
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args,
{
"batch_size": config.batch_size,
"max_batch_size": config.max_batch_size,
"device": config.device,
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
else:
if not isinstance(config.model, lm_eval.api.model.LM):
if not isinstance(model, lm_eval.api.model.LM):
raise TypeError(
f"The value of `model` passed to simple_evaluate() was of type {type(config.model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
)
eval_logger.info("Using pre-initialized model")
lm = config.model
lm = model
if config.use_cache is not None:
eval_logger.info(
f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}"
)
if use_cache is not None:
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
config.use_cache
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank"
@@ -229,10 +256,17 @@ def simple_evaluate(
)
if task_manager is None:
task_manager = TaskManager(metadata=config.metadata)
metadata = (
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
if isinstance(model_args, dict)
else {}
) | (metadata or {})
task_manager = TaskManager(metadata=metadata)
task_dict = get_task_dict(
config.tasks,
tasks,
task_manager,
)
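The metadata merge above relies on the dict union operator, where keys from the right operand win, so an explicit `metadata` argument overrides keys parsed out of `model_args`. A small illustration with hypothetical values:

from lm_eval.utils import simple_parse_args_string

model_args = "pretrained=gpt2,max_length=2048"
metadata = {"max_length": 4096}  # hypothetical explicit override

merged = simple_parse_args_string(model_args) | (metadata or {})
# merged["max_length"] == 4096: the explicit metadata wins over model_args.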
@@ -249,17 +283,15 @@ def simple_evaluate(
else:
if task_obj.get_config("output_type") == "generate_until":
if config.gen_kwargs is not None:
if gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs",
value=config.gen_kwargs,
update=True,
key="generation_kwargs", value=gen_kwargs, update=True
)
eval_logger.info(
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
)
if config.predict_only:
if predict_only:
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
@@ -268,16 +300,16 @@ def simple_evaluate(
# override tasks' fewshot values to the provided num_fewshot arg value
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
if config.num_fewshot is not None:
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (
@@ -285,7 +317,7 @@ def simple_evaluate(
) is None:
task_obj.set_config(key="num_fewshot", value=0)
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=config.seed[3])
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
adjusted_task_dict[task_name] = task_obj
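The num_fewshot override rules in the hunk above reduce to the following decision table (sketch):

# task config num_fewshot | CLI num_fewshot | effective value
# 0                       | 5               | 0  (an explicit 0 in the task config is never overwritten)
# 2                       | 5               | 5  (CLI value overwrites the task default, with a warning)
# 2                       | unset           | 2  (task default kept)
# unset                   | unset           | 0  (fallback default)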
@@ -293,55 +325,51 @@
task_dict = _adjust_config(task_dict)
if config.check_integrity:
run_task_tests(task_list=config.tasks)
if check_integrity:
run_task_tests(task_list=tasks)
if evaluation_tracker is not None:
evaluation_tracker.general_config_tracker.log_experiment_args(
model_source=config.model,
model_args=config.model_args,
system_instruction=config.system_instruction,
chat_template=lm.chat_template(config.apply_chat_template)
if config.apply_chat_template
model_source=model,
model_args=model_args,
system_instruction=system_instruction,
chat_template=lm.chat_template(apply_chat_template)
if apply_chat_template
else None,
fewshot_as_multiturn=config.fewshot_as_multiturn,
fewshot_as_multiturn=fewshot_as_multiturn,
)
results = evaluate(
lm=lm,
task_dict=task_dict,
limit=config.limit,
samples=config.samples,
cache_requests=config.cache_requests,
rewrite_requests_cache=config.request_caching_args.get(
"rewrite_requests_cache", False
),
limit=limit,
samples=samples,
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
write_out=config.write_out,
log_samples=True if config.predict_only else config.log_samples,
system_instruction=config.system_instruction,
apply_chat_template=config.apply_chat_template,
fewshot_as_multiturn=config.fewshot_as_multiturn,
verbosity=config.verbosity,
confirm_run_unsafe_code=config.confirm_run_unsafe_code,
write_out=write_out,
log_samples=True if predict_only else log_samples,
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
)
if config.verbosity is not None:
setup_logging(verbosity=config.verbosity)
if verbosity is not None:
setup_logging(verbosity=verbosity)
if lm.rank == 0:
if isinstance(config.model, str):
model_name = config.model
elif hasattr(config.model, "config") and hasattr(
config.model.config, "_name_or_path"
):
model_name = config.model.config._name_or_path
if isinstance(model, str):
model_name = model
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
model_name = model.config._name_or_path
else:
model_name = type(config.model).__name__
model_name = type(model).__name__
# add info about the model and few shot config
results["config"] = {
"model": model_name,
"model_args": config.model_args,
"model_args": model_args,
}
# add more detailed model info if available
if isinstance(lm, lm_eval.models.huggingface.HFLM):
@@ -349,19 +377,19 @@ def simple_evaluate(
# add info about execution
results["config"].update(
{
"batch_size": config.batch_size,
"batch_size": batch_size,
"batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
),
"device": config.device,
"use_cache": config.use_cache,
"limit": config.limit,
"device": device,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": config.gen_kwargs,
"random_seed": config.seed[0],
"numpy_seed": config.seed[1],
"torch_seed": config.seed[2],
"fewshot_seed": config.seed[3],
"gen_kwargs": gen_kwargs,
"random_seed": random_seed,
"numpy_seed": numpy_random_seed,
"torch_seed": torch_random_seed,
"fewshot_seed": fewshot_random_seed,
}
)
results["git_hash"] = get_git_commit_hash()
@@ -459,7 +487,7 @@ def evaluate(
for task_output in eval_tasks:
task: Task = task_output.task
if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
raise ValueError(
@@ -470,10 +498,6 @@ def evaluate(
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end validation check
# Cache the limit arg.
@@ -731,11 +755,11 @@ def evaluate(
return None
# def request_caching_arg_to_dict(cache_requests: str) -> dict:
# request_caching_args = {
# "cache_requests": cache_requests in {"true", "refresh"},
# "rewrite_requests_cache": cache_requests == "refresh",
# "delete_requests_cache": cache_requests == "delete",
# }
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
# return request_caching_args
return request_caching_args
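Usage sketch for the re-enabled helper above, mapping each CLI cache mode onto its flag dict:

request_caching_arg_to_dict("true")     # {"cache_requests": True,  "rewrite_requests_cache": False, "delete_requests_cache": False}
request_caching_arg_to_dict("refresh")  # {"cache_requests": True,  "rewrite_requests_cache": True,  "delete_requests_cache": False}
request_caching_arg_to_dict("delete")   # {"cache_requests": False, "rewrite_requests_cache": False, "delete_requests_cache": True}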