Commit 61fc5bfd authored by artemorloff

pre-commit prettify

parent c1e43393
@@ -8,22 +8,22 @@ from pathlib import Path
from typing import Union
from lm_eval import evaluator, utils
from lm_eval.api.eval_config import (
EvaluationConfig,
TrackExplicitAction,
TrackExplicitStoreTrue,
)
# from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import (
handle_non_serializable,
load_yaml_config,
make_table,
request_caching_arg_to_dict,
# non_default_update,
# parse_namespace,
)
from lm_eval.api.eval_config import (
TrackExplicitAction,
TrackExplicitStoreTrue,
EvaluationConfig,
)
def try_parse_json(value: str) -> Union[str, dict, None]:
......
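The first file above imports `TrackExplicitAction` and `TrackExplicitStoreTrue` from `lm_eval.api.eval_config`; judging by the names and by the `explicit_args` handling further down, these are argparse actions that record which flags the user actually typed. Below is a minimal sketch of that idea with a hypothetical class name; the real implementation is not shown in this commit.

```python
import argparse


class TrackExplicitActionSketch(argparse.Action):
    """Hypothetical stand-in for TrackExplicitAction: store the value and
    remember that the flag was explicitly passed on the command line."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        explicit = getattr(namespace, "_explicit_args", set())
        explicit.add(self.dest)
        setattr(namespace, "_explicit_args", explicit)


parser = argparse.ArgumentParser()
parser.add_argument("--model", action=TrackExplicitActionSketch, default="hf")
args = parser.parse_args(["--model", "vllm"])
print(args.model, args._explicit_args)  # vllm {'model'}
```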
import argparse
import os
import yaml
from argparse import Namespace
from typing import Any, Dict, Union, Optional
import argparse
from typing import Any, Dict, Optional, Union
import yaml
from pydantic import BaseModel
from lm_eval.utils import simple_parse_args_string
@@ -21,6 +23,7 @@ class EvaluationConfig(BaseModel):
Simple config container for language-model evaluation.
No content validation here—just holds whatever comes from YAML or CLI.
"""
config: Optional[str]
model: Optional[str]
model_args: Optional[dict]
@@ -54,7 +57,6 @@ class EvaluationConfig(BaseModel):
metadata: Optional[dict]
request_caching_args: Optional[dict] = None
@staticmethod
def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
"""
@@ -90,7 +92,6 @@ class EvaluationConfig(BaseModel):
return config, non_default_args
@staticmethod
def non_default_update(console_dict, local_dict, non_default_args):
"""
@@ -117,7 +118,6 @@ class EvaluationConfig(BaseModel):
return result_config
@classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluationConfig":
"""
@@ -142,7 +142,9 @@ class EvaluationConfig(BaseModel):
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {cfg_path}: {e}")
if not isinstance(yaml_data, dict):
raise ValueError(f"YAML root must be a mapping, got {type(yaml_data).__name__}")
raise ValueError(
f"YAML root must be a mapping, got {type(yaml_data).__name__}"
)
config_data.update(yaml_data)
# 3. Override with any CLI args the user explicitly passed
@@ -153,7 +155,9 @@ class EvaluationConfig(BaseModel):
# config_data[key] = val
print(f"YAML: {config_data}")
print(f"CLI: {args_dict}")
dict_config = EvaluationConfig.non_default_update(args_dict, config_data, explicit_args)
dict_config = EvaluationConfig.non_default_update(
args_dict, config_data, explicit_args
)
# 4. Instantiate the Pydantic model (no further validation here)
return cls(**dict_config)
......
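As a reading aid (not part of the commit), the reflowed `non_default_update(args_dict, config_data, explicit_args)` call above appears to implement a three-way precedence: YAML values override argparse defaults, and flags the user explicitly passed override the YAML. A self-contained sketch of that precedence, with a hypothetical `merge` helper and made-up values:

```python
def merge(cli_args: dict, yaml_args: dict, explicit: set) -> dict:
    # Hypothetical restatement of the precedence, not a real lm_eval function.
    result = dict(cli_args)      # argparse values, defaults included
    result.update(yaml_args)     # YAML config overrides the defaults
    for key in explicit:         # flags the user actually typed win over YAML
        if key in cli_args:
            result[key] = cli_args[key]
    return result


print(merge({"model": "hf", "num_fewshot": 0},
            {"num_fewshot": 5, "tasks": "mmlu"},
            explicit={"model"}))
# -> {'model': 'hf', 'num_fewshot': 5, 'tasks': 'mmlu'}
```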
@@ -4,7 +4,7 @@ import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import torch
@@ -13,6 +13,7 @@ import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.api.eval_config import EvaluationConfig
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_group_results,
@@ -34,7 +35,6 @@ from lm_eval.utils import (
setup_logging,
simple_parse_args_string,
)
from lm_eval.api.eval_config import EvaluationConfig
if TYPE_CHECKING:
@@ -215,7 +215,9 @@ def simple_evaluate(
lm = config.model
if config.use_cache is not None:
eval_logger.info(f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}")
eval_logger.info(
f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}"
)
lm = lm_eval.api.model.CachingLM(
lm,
config.use_cache
@@ -249,7 +251,9 @@
if task_obj.get_config("output_type") == "generate_until":
if config.gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=config.gen_kwargs, update=True
key="generation_kwargs",
value=config.gen_kwargs,
update=True,
)
eval_logger.info(
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
@@ -271,7 +275,7 @@
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
else:
@@ -309,7 +313,9 @@
limit=config.limit,
samples=config.samples,
cache_requests=config.cache_requests,
rewrite_requests_cache=config.request_caching_args.get("rewrite_requests_cache", False),
rewrite_requests_cache=config.request_caching_args.get(
"rewrite_requests_cache", False
),
bootstrap_iters=bootstrap_iters,
write_out=config.write_out,
log_samples=True if config.predict_only else config.log_samples,
@@ -325,7 +331,9 @@
if lm.rank == 0:
if isinstance(config.model, str):
model_name = config.model
elif hasattr(config.model, "config") and hasattr(config.model.config, "_name_or_path"):
elif hasattr(config.model, "config") and hasattr(
config.model.config, "_name_or_path"
):
model_name = config.model.config._name_or_path
else:
model_name = type(config.model).__name__
......
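For completeness, the model-name fallback in the last hunk can be exercised on its own; the classes below are hypothetical stand-ins for a Hugging Face-style model object, and `resolve_model_name` simply restates the branching shown above:

```python
class _FakeHFConfig:
    _name_or_path = "org/some-model"  # hypothetical value


class _FakeHFModel:
    config = _FakeHFConfig()


def resolve_model_name(model):
    # Mirrors the if/elif/else in the hunk above.
    if isinstance(model, str):
        return model
    if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
        return model.config._name_or_path
    return type(model).__name__


print(resolve_model_name("hf"))            # -> hf
print(resolve_model_name(_FakeHFModel()))  # -> org/some-model
print(resolve_model_name(object()))        # -> object
```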