Commit caab7820 authored by Baber

nit

parent 601be343
@@ -357,98 +357,96 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         parser = setup_parser()
         args = parse_eval_args(parser)
-    config = EvaluationConfig.from_cli(args)
+    cfg = EvaluationConfig.from_cli(args)
     if args.wandb_args:
-        wandb_logger = WandbLogger(config.wandb_args, config.wandb_config_args)
-    utils.setup_logging(config.verbosity)
+        wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)
+    utils.setup_logging(cfg.verbosity)
     eval_logger = logging.getLogger(__name__)
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     # update the evaluation tracker args with the output path and the HF token
-    if config.output_path:
-        config.hf_hub_log_args["output_path"] = config.output_path
+    if cfg.output_path:
+        cfg.hf_hub_log_args["output_path"] = cfg.output_path
     if os.environ.get("HF_TOKEN", None):
-        config.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
-    evaluation_tracker_args = config.hf_hub_log_args
+        cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
+    evaluation_tracker_args = cfg.hf_hub_log_args
     evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
-    if config.predict_only:
-        config.log_samples = True
-    if (config.log_samples or config.predict_only) and not config.output_path:
+    if cfg.predict_only:
+        cfg.log_samples = True
+    if (cfg.log_samples or cfg.predict_only) and not cfg.output_path:
         raise ValueError(
             "Specify --output_path if providing --log_samples or --predict_only"
         )
-    if config.fewshot_as_multiturn and config.apply_chat_template is False:
+    if cfg.fewshot_as_multiturn and cfg.apply_chat_template is False:
         raise ValueError(
             "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
         )
-    if config.include_path is not None:
-        eval_logger.info(f"Including path: {config.include_path}")
-    metadata = (config.model_args) | (config.metadata)
-    config.metadata = metadata
+    if cfg.include_path is not None:
+        eval_logger.info(f"Including path: {cfg.include_path}")
+    metadata = (cfg.model_args) | (cfg.metadata)
+    cfg.metadata = metadata
     # task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)
-    task_manager = TaskManager(include_path=config.include_path, metadata=metadata)
-    if "push_samples_to_hub" in evaluation_tracker_args and not config.log_samples:
+    task_manager = TaskManager(include_path=cfg.include_path, metadata=metadata)
+    if "push_samples_to_hub" in evaluation_tracker_args and not cfg.log_samples:
         eval_logger.warning(
             "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
         )
-    if config.limit:
+    if cfg.limit:
         eval_logger.warning(
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
-    if config.samples:
-        assert config.limit is None, (
-            "If --samples is not None, then --limit must be None."
-        )
-        if (samples := Path(config.samples)).is_file():
-            config.samples = json.loads(samples.read_text())
+    if cfg.samples:
+        assert cfg.limit is None, "If --samples is not None, then --limit must be None."
+        if (samples := Path(cfg.samples)).is_file():
+            cfg.samples = json.loads(samples.read_text())
         else:
-            config.samples = json.loads(config.samples)
-    if config.tasks is None:
+            cfg.samples = json.loads(cfg.samples)
+    if cfg.tasks is None:
         eval_logger.error("Need to specify task to evaluate.")
         sys.exit()
-    elif config.tasks == "list":
+    elif cfg.tasks == "list":
         print(task_manager.list_all_tasks())
         sys.exit()
-    elif config.tasks == "list_groups":
+    elif cfg.tasks == "list_groups":
         print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
         sys.exit()
-    elif config.tasks == "list_tags":
+    elif cfg.tasks == "list_tags":
         print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
         sys.exit()
-    elif config.tasks == "list_subtasks":
+    elif cfg.tasks == "list_subtasks":
         print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
         sys.exit()
     else:
-        if os.path.isdir(config.tasks):
+        if os.path.isdir(cfg.tasks):
             import glob
             task_names = []
-            yaml_path = os.path.join(config.tasks, "*.yaml")
+            yaml_path = os.path.join(cfg.tasks, "*.yaml")
             for yaml_file in glob.glob(yaml_path):
-                config = utils.load_yaml_config(yaml_file)
-                task_names.append(config)
+                cfg = utils.load_yaml_config(yaml_file)
+                task_names.append(cfg)
         else:
-            task_list = config.tasks.split(",")
+            task_list = cfg.tasks.split(",")
             task_names = task_manager.match_tasks(task_list)
             for task in [task for task in task_list if task not in task_names]:
                 if os.path.isfile(task):
-                    config = utils.load_yaml_config(task)
-                    task_names.append(config)
+                    cfg = utils.load_yaml_config(task)
+                    task_names.append(cfg)
             task_missing = [
                 task for task in task_list if task not in task_names and "*" not in task
             ] # we don't want errors if a wildcard ("*") task name was used
@@ -462,10 +460,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             raise ValueError(
                 f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
             )
-    config.tasks = task_names
+    cfg.tasks = task_names
     # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
-    if config.trust_remote_code:
+    if cfg.trust_remote_code:
         eval_logger.info(
             "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
         )
@@ -476,7 +474,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-        config.model_args["trust_remote_code"] = True
+        cfg.model_args["trust_remote_code"] = True
     (
         eval_logger.info(f"Selected Tasks: {task_names}")
         if eval_logger.getEffectiveLevel() >= logging.INFO
@@ -484,66 +482,66 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     )
     request_caching_args = request_caching_arg_to_dict(
-        cache_requests=config.cache_requests
+        cache_requests=cfg.cache_requests
     )
-    config.request_caching_args = request_caching_args
-    print(f"CONFIG_AFTER: {config}")
+    cfg.request_caching_args = request_caching_args
+    print(f"CONFIG_AFTER: {cfg}")
     results = evaluator.simple_evaluate(
-        model=config.model,
-        model_args=config.model_args,
-        tasks=config.tasks,
-        num_fewshot=config.num_fewshot,
-        batch_size=config.batch_size,
-        max_batch_size=config.max_batch_size,
-        device=config.device,
-        use_cache=config.use_cache,
-        cache_requests=config.request_caching_args.get("cache_requests", False),
-        rewrite_requests_cache=config.request_caching_args.get(
+        model=cfg.model,
+        model_args=cfg.model_args,
+        tasks=cfg.tasks,
+        num_fewshot=cfg.num_fewshot,
+        batch_size=cfg.batch_size,
+        max_batch_size=cfg.max_batch_size,
+        device=cfg.device,
+        use_cache=cfg.use_cache,
+        cache_requests=cfg.request_caching_args.get("cache_requests", False),
+        rewrite_requests_cache=cfg.request_caching_args.get(
             "rewrite_requests_cache", False
         ),
-        delete_requests_cache=config.request_caching_args.get(
+        delete_requests_cache=cfg.request_caching_args.get(
             "delete_requests_cache", False
         ),
-        limit=config.limit,
-        samples=config.samples,
-        check_integrity=config.check_integrity,
-        write_out=config.write_out,
-        log_samples=config.log_samples,
+        limit=cfg.limit,
+        samples=cfg.samples,
+        check_integrity=cfg.check_integrity,
+        write_out=cfg.write_out,
+        log_samples=cfg.log_samples,
         evaluation_tracker=evaluation_tracker,
-        system_instruction=config.system_instruction,
-        apply_chat_template=config.apply_chat_template,
-        fewshot_as_multiturn=config.fewshot_as_multiturn,
-        gen_kwargs=config.gen_kwargs,
+        system_instruction=cfg.system_instruction,
+        apply_chat_template=cfg.apply_chat_template,
+        fewshot_as_multiturn=cfg.fewshot_as_multiturn,
+        gen_kwargs=cfg.gen_kwargs,
         task_manager=task_manager,
-        verbosity=config.verbosity,
-        predict_only=config.predict_only,
-        random_seed=config.seed[0] if config.seed else None,
-        numpy_random_seed=config.seed[1] if config.seed else None,
-        torch_random_seed=config.seed[2] if config.seed else None,
-        fewshot_random_seed=config.seed[3] if config.seed else None,
-        confirm_run_unsafe_code=config.confirm_run_unsafe_code,
-        metadata=config.metadata,
+        verbosity=cfg.verbosity,
+        predict_only=cfg.predict_only,
+        random_seed=cfg.seed[0] if cfg.seed else None,
+        numpy_random_seed=cfg.seed[1] if cfg.seed else None,
+        torch_random_seed=cfg.seed[2] if cfg.seed else None,
+        fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
+        confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
+        metadata=cfg.metadata,
     )
     if results is not None:
-        if config.log_samples:
+        if cfg.log_samples:
             samples = results.pop("samples")
         dumped = json.dumps(
             results, indent=2, default=handle_non_serializable, ensure_ascii=False
         )
-        if config.show_config:
+        if cfg.show_config:
             print(dumped)
         batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
         # Add W&B logging
-        if config.wandb_args:
+        if cfg.wandb_args:
             try:
                 wandb_logger.post_init(results)
                 wandb_logger.log_eval_result()
-                if config.log_samples:
+                if cfg.log_samples:
                     wandb_logger.log_eval_samples(samples)
             except Exception as e:
                 eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
@@ -552,7 +550,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             results=results, samples=samples if args.log_samples else None
         )
-        if config.log_samples:
+        if cfg.log_samples:
             for task_name, _ in results["configs"].items():
                 evaluation_tracker.save_results_samples(
                     task_name=task_name, samples=samples[task_name]
@@ -565,14 +563,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             evaluation_tracker.recreate_metadata_card()
         print(
-            f"{config.model} ({config.model_args}), gen_kwargs: ({config.gen_kwargs}), limit: {config.limit}, num_fewshot: {config.num_fewshot}, "
-            f"batch_size: {config.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
+            f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
+            f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
         )
         print(make_table(results))
         if "groups" in results:
             print(make_table(results, "groups"))
-    if config.wandb_args:
+    if cfg.wandb_args:
         # Tear down wandb run once all the logging is done.
         wandb_logger.run.finish()
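Note on the pattern this diff leans on: `EvaluationConfig.from_cli(args)` folds the parsed argparse namespace into a single config object (`cfg`) whose attributes (`cfg.tasks`, `cfg.limit`, `cfg.log_samples`, ...) replace direct `args.*` access. A minimal sketch of how such a constructor might look is given below; the class name, field list, and implementation are illustrative assumptions drawn only from the attributes visible in this diff, not the project's actual `EvaluationConfig` code.

# Hypothetical sketch, not the real EvaluationConfig: maps argparse output onto a dataclass.
from dataclasses import dataclass, field, fields
from typing import Optional, Union
import argparse


@dataclass
class EvaluationConfigSketch:
    model: Optional[str] = None
    model_args: dict = field(default_factory=dict)  # treated as a dict in the diff (| merge, item assignment)
    tasks: Optional[Union[str, list]] = None
    num_fewshot: Optional[int] = None
    batch_size: Optional[Union[int, str]] = None
    limit: Optional[float] = None
    log_samples: bool = False
    verbosity: str = "INFO"

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> "EvaluationConfigSketch":
        # Keep only the namespace attributes this sketch declares as fields.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in vars(args).items() if k in known})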