Commit caab7820 authored by Baber

nit

parent 601be343
@@ -357,98 +357,96 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         parser = setup_parser()
         args = parse_eval_args(parser)
-    config = EvaluationConfig.from_cli(args)
+    cfg = EvaluationConfig.from_cli(args)
     if args.wandb_args:
-        wandb_logger = WandbLogger(config.wandb_args, config.wandb_config_args)
+        wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)
-    utils.setup_logging(config.verbosity)
+    utils.setup_logging(cfg.verbosity)
     eval_logger = logging.getLogger(__name__)
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     # update the evaluation tracker args with the output path and the HF token
-    if config.output_path:
-        config.hf_hub_log_args["output_path"] = config.output_path
+    if cfg.output_path:
+        cfg.hf_hub_log_args["output_path"] = cfg.output_path
     if os.environ.get("HF_TOKEN", None):
-        config.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
+        cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
-    evaluation_tracker_args = config.hf_hub_log_args
+    evaluation_tracker_args = cfg.hf_hub_log_args
     evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
-    if config.predict_only:
-        config.log_samples = True
+    if cfg.predict_only:
+        cfg.log_samples = True
-    if (config.log_samples or config.predict_only) and not config.output_path:
+    if (cfg.log_samples or cfg.predict_only) and not cfg.output_path:
         raise ValueError(
             "Specify --output_path if providing --log_samples or --predict_only"
         )
-    if config.fewshot_as_multiturn and config.apply_chat_template is False:
+    if cfg.fewshot_as_multiturn and cfg.apply_chat_template is False:
         raise ValueError(
             "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
         )
-    if config.include_path is not None:
-        eval_logger.info(f"Including path: {config.include_path}")
+    if cfg.include_path is not None:
+        eval_logger.info(f"Including path: {cfg.include_path}")
-    metadata = (config.model_args) | (config.metadata)
-    config.metadata = metadata
+    metadata = (cfg.model_args) | (cfg.metadata)
+    cfg.metadata = metadata
     # task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)
-    task_manager = TaskManager(include_path=config.include_path, metadata=metadata)
+    task_manager = TaskManager(include_path=cfg.include_path, metadata=metadata)
-    if "push_samples_to_hub" in evaluation_tracker_args and not config.log_samples:
+    if "push_samples_to_hub" in evaluation_tracker_args and not cfg.log_samples:
         eval_logger.warning(
             "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
         )
-    if config.limit:
+    if cfg.limit:
         eval_logger.warning(
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
-    if config.samples:
-        assert config.limit is None, (
-            "If --samples is not None, then --limit must be None."
-        )
-        if (samples := Path(config.samples)).is_file():
-            config.samples = json.loads(samples.read_text())
+    if cfg.samples:
+        assert cfg.limit is None, "If --samples is not None, then --limit must be None."
+        if (samples := Path(cfg.samples)).is_file():
+            cfg.samples = json.loads(samples.read_text())
         else:
-            config.samples = json.loads(config.samples)
+            cfg.samples = json.loads(cfg.samples)
-    if config.tasks is None:
+    if cfg.tasks is None:
         eval_logger.error("Need to specify task to evaluate.")
         sys.exit()
-    elif config.tasks == "list":
+    elif cfg.tasks == "list":
         print(task_manager.list_all_tasks())
         sys.exit()
-    elif config.tasks == "list_groups":
+    elif cfg.tasks == "list_groups":
         print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
         sys.exit()
-    elif config.tasks == "list_tags":
+    elif cfg.tasks == "list_tags":
         print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
         sys.exit()
-    elif config.tasks == "list_subtasks":
+    elif cfg.tasks == "list_subtasks":
         print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
         sys.exit()
     else:
-        if os.path.isdir(config.tasks):
+        if os.path.isdir(cfg.tasks):
             import glob
             task_names = []
-            yaml_path = os.path.join(config.tasks, "*.yaml")
+            yaml_path = os.path.join(cfg.tasks, "*.yaml")
             for yaml_file in glob.glob(yaml_path):
-                config = utils.load_yaml_config(yaml_file)
-                task_names.append(config)
+                cfg = utils.load_yaml_config(yaml_file)
+                task_names.append(cfg)
         else:
-            task_list = config.tasks.split(",")
+            task_list = cfg.tasks.split(",")
             task_names = task_manager.match_tasks(task_list)
             for task in [task for task in task_list if task not in task_names]:
                 if os.path.isfile(task):
-                    config = utils.load_yaml_config(task)
-                    task_names.append(config)
+                    cfg = utils.load_yaml_config(task)
+                    task_names.append(cfg)
             task_missing = [
                 task for task in task_list if task not in task_names and "*" not in task
             ] # we don't want errors if a wildcard ("*") task name was used
@@ -462,10 +460,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
            raise ValueError(
                f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
            )
-    config.tasks = task_names
+    cfg.tasks = task_names
     # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
-    if config.trust_remote_code:
+    if cfg.trust_remote_code:
         eval_logger.info(
             "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
         )
@@ -476,7 +474,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-        config.model_args["trust_remote_code"] = True
+        cfg.model_args["trust_remote_code"] = True
     (
         eval_logger.info(f"Selected Tasks: {task_names}")
         if eval_logger.getEffectiveLevel() >= logging.INFO
@@ -484,66 +482,66 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     )
     request_caching_args = request_caching_arg_to_dict(
-        cache_requests=config.cache_requests
+        cache_requests=cfg.cache_requests
     )
-    config.request_caching_args = request_caching_args
+    cfg.request_caching_args = request_caching_args
-    print(f"CONFIG_AFTER: {config}")
+    print(f"CONFIG_AFTER: {cfg}")
     results = evaluator.simple_evaluate(
-        model=config.model,
-        model_args=config.model_args,
-        tasks=config.tasks,
-        num_fewshot=config.num_fewshot,
-        batch_size=config.batch_size,
-        max_batch_size=config.max_batch_size,
-        device=config.device,
-        use_cache=config.use_cache,
-        cache_requests=config.request_caching_args.get("cache_requests", False),
-        rewrite_requests_cache=config.request_caching_args.get(
+        model=cfg.model,
+        model_args=cfg.model_args,
+        tasks=cfg.tasks,
+        num_fewshot=cfg.num_fewshot,
+        batch_size=cfg.batch_size,
+        max_batch_size=cfg.max_batch_size,
+        device=cfg.device,
+        use_cache=cfg.use_cache,
+        cache_requests=cfg.request_caching_args.get("cache_requests", False),
+        rewrite_requests_cache=cfg.request_caching_args.get(
            "rewrite_requests_cache", False
        ),
-        delete_requests_cache=config.request_caching_args.get(
+        delete_requests_cache=cfg.request_caching_args.get(
            "delete_requests_cache", False
        ),
-        limit=config.limit,
-        samples=config.samples,
-        check_integrity=config.check_integrity,
-        write_out=config.write_out,
-        log_samples=config.log_samples,
+        limit=cfg.limit,
+        samples=cfg.samples,
+        check_integrity=cfg.check_integrity,
+        write_out=cfg.write_out,
+        log_samples=cfg.log_samples,
         evaluation_tracker=evaluation_tracker,
-        system_instruction=config.system_instruction,
-        apply_chat_template=config.apply_chat_template,
-        fewshot_as_multiturn=config.fewshot_as_multiturn,
-        gen_kwargs=config.gen_kwargs,
+        system_instruction=cfg.system_instruction,
+        apply_chat_template=cfg.apply_chat_template,
+        fewshot_as_multiturn=cfg.fewshot_as_multiturn,
+        gen_kwargs=cfg.gen_kwargs,
         task_manager=task_manager,
-        verbosity=config.verbosity,
-        predict_only=config.predict_only,
-        random_seed=config.seed[0] if config.seed else None,
-        numpy_random_seed=config.seed[1] if config.seed else None,
-        torch_random_seed=config.seed[2] if config.seed else None,
-        fewshot_random_seed=config.seed[3] if config.seed else None,
-        confirm_run_unsafe_code=config.confirm_run_unsafe_code,
-        metadata=config.metadata,
+        verbosity=cfg.verbosity,
+        predict_only=cfg.predict_only,
+        random_seed=cfg.seed[0] if cfg.seed else None,
+        numpy_random_seed=cfg.seed[1] if cfg.seed else None,
+        torch_random_seed=cfg.seed[2] if cfg.seed else None,
+        fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
+        confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
+        metadata=cfg.metadata,
     )
     if results is not None:
-        if config.log_samples:
+        if cfg.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=handle_non_serializable, ensure_ascii=False
        )
-        if config.show_config:
+        if cfg.show_config:
            print(dumped)
        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
        # Add W&B logging
-        if config.wandb_args:
+        if cfg.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
-                if config.log_samples:
+                if cfg.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
@@ -552,7 +550,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
            results=results, samples=samples if args.log_samples else None
        )
-        if config.log_samples:
+        if cfg.log_samples:
            for task_name, _ in results["configs"].items():
                evaluation_tracker.save_results_samples(
                    task_name=task_name, samples=samples[task_name]
@@ -565,14 +563,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
            evaluation_tracker.recreate_metadata_card()
        print(
-            f"{config.model} ({config.model_args}), gen_kwargs: ({config.gen_kwargs}), limit: {config.limit}, num_fewshot: {config.num_fewshot}, "
-            f"batch_size: {config.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
+            f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
+            f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))
-        if config.wandb_args:
+        if cfg.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()
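
Note (not part of the commit): the three keys read above via `cfg.request_caching_args.get(...)` are produced by `request_caching_arg_to_dict(cache_requests=cfg.cache_requests)`. A minimal sketch of that mapping, assuming `--cache_requests` accepts the values `true`, `refresh`, and `delete` (an assumption; the accepted values and exact implementation are not shown in this diff):

def request_caching_arg_to_dict(cache_requests: str) -> dict:
    # Sketch only: maps the single --cache_requests value onto the three
    # flags consumed by evaluator.simple_evaluate() in the diff above.
    return {
        "cache_requests": cache_requests in ("true", "refresh"),
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }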