Commit 61fc5bfd authored by artemorloff's avatar artemorloff
Browse files

pre-commit prettify

parent c1e43393
...@@ -8,22 +8,22 @@ from pathlib import Path ...@@ -8,22 +8,22 @@ from pathlib import Path
from typing import Union from typing import Union
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.api.eval_config import (
EvaluationConfig,
TrackExplicitAction,
TrackExplicitStoreTrue,
)
# from lm_eval.evaluator import request_caching_arg_to_dict # from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager from lm_eval.tasks import TaskManager
from lm_eval.utils import ( from lm_eval.utils import (
handle_non_serializable, handle_non_serializable,
load_yaml_config,
make_table, make_table,
request_caching_arg_to_dict, request_caching_arg_to_dict,
# non_default_update, # non_default_update,
# parse_namespace, # parse_namespace,
) )
from lm_eval.api.eval_config import (
TrackExplicitAction,
TrackExplicitStoreTrue,
EvaluationConfig,
)
def try_parse_json(value: str) -> Union[str, dict, None]: def try_parse_json(value: str) -> Union[str, dict, None]:
......
import argparse
import os import os
import yaml
from argparse import Namespace from argparse import Namespace
from typing import Any, Dict, Union, Optional from typing import Any, Dict, Optional, Union
import argparse
import yaml
from pydantic import BaseModel from pydantic import BaseModel
from lm_eval.utils import simple_parse_args_string from lm_eval.utils import simple_parse_args_string
...@@ -21,6 +23,7 @@ class EvaluationConfig(BaseModel): ...@@ -21,6 +23,7 @@ class EvaluationConfig(BaseModel):
Simple config container for language-model evaluation. Simple config container for language-model evaluation.
No content validation here—just holds whatever comes from YAML or CLI. No content validation here—just holds whatever comes from YAML or CLI.
""" """
config: Optional[str] config: Optional[str]
model: Optional[str] model: Optional[str]
model_args: Optional[dict] model_args: Optional[dict]
...@@ -54,7 +57,6 @@ class EvaluationConfig(BaseModel): ...@@ -54,7 +57,6 @@ class EvaluationConfig(BaseModel):
metadata: Optional[dict] metadata: Optional[dict]
request_caching_args: Optional[dict] = None request_caching_args: Optional[dict] = None
@staticmethod @staticmethod
def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]: def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
""" """
...@@ -90,7 +92,6 @@ class EvaluationConfig(BaseModel): ...@@ -90,7 +92,6 @@ class EvaluationConfig(BaseModel):
return config, non_default_args return config, non_default_args
@staticmethod @staticmethod
def non_default_update(console_dict, local_dict, non_default_args): def non_default_update(console_dict, local_dict, non_default_args):
""" """
...@@ -117,7 +118,6 @@ class EvaluationConfig(BaseModel): ...@@ -117,7 +118,6 @@ class EvaluationConfig(BaseModel):
return result_config return result_config
@classmethod @classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluationConfig": def from_cli(cls, namespace: Namespace) -> "EvaluationConfig":
""" """
...@@ -142,7 +142,9 @@ class EvaluationConfig(BaseModel): ...@@ -142,7 +142,9 @@ class EvaluationConfig(BaseModel):
except yaml.YAMLError as e: except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {cfg_path}: {e}") raise ValueError(f"Invalid YAML in {cfg_path}: {e}")
if not isinstance(yaml_data, dict): if not isinstance(yaml_data, dict):
raise ValueError(f"YAML root must be a mapping, got {type(yaml_data).__name__}") raise ValueError(
f"YAML root must be a mapping, got {type(yaml_data).__name__}"
)
config_data.update(yaml_data) config_data.update(yaml_data)
# 3. Override with any CLI args the user explicitly passed # 3. Override with any CLI args the user explicitly passed
...@@ -153,7 +155,9 @@ class EvaluationConfig(BaseModel): ...@@ -153,7 +155,9 @@ class EvaluationConfig(BaseModel):
# config_data[key] = val # config_data[key] = val
print(f"YAML: {config_data}") print(f"YAML: {config_data}")
print(f"CLI: {args_dict}") print(f"CLI: {args_dict}")
dict_config = EvaluationConfig.non_default_update(args_dict, config_data, explicit_args) dict_config = EvaluationConfig.non_default_update(
args_dict, config_data, explicit_args
)
# 4. Instantiate the Pydantic model (no further validation here) # 4. Instantiate the Pydantic model (no further validation here)
return cls(**dict_config) return cls(**dict_config)
......
...@@ -4,7 +4,7 @@ import logging ...@@ -4,7 +4,7 @@ import logging
import random import random
import time import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -13,6 +13,7 @@ import lm_eval.api.metrics ...@@ -13,6 +13,7 @@ import lm_eval.api.metrics
import lm_eval.api.registry import lm_eval.api.registry
import lm_eval.api.task import lm_eval.api.task
import lm_eval.models import lm_eval.models
from lm_eval.api.eval_config import EvaluationConfig
from lm_eval.caching.cache import delete_cache from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import ( from lm_eval.evaluator_utils import (
consolidate_group_results, consolidate_group_results,
...@@ -34,7 +35,6 @@ from lm_eval.utils import ( ...@@ -34,7 +35,6 @@ from lm_eval.utils import (
setup_logging, setup_logging,
simple_parse_args_string, simple_parse_args_string,
) )
from lm_eval.api.eval_config import EvaluationConfig
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -215,7 +215,9 @@ def simple_evaluate( ...@@ -215,7 +215,9 @@ def simple_evaluate(
lm = config.model lm = config.model
if config.use_cache is not None: if config.use_cache is not None:
eval_logger.info(f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}") eval_logger.info(
f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}"
)
lm = lm_eval.api.model.CachingLM( lm = lm_eval.api.model.CachingLM(
lm, lm,
config.use_cache config.use_cache
...@@ -249,7 +251,9 @@ def simple_evaluate( ...@@ -249,7 +251,9 @@ def simple_evaluate(
if task_obj.get_config("output_type") == "generate_until": if task_obj.get_config("output_type") == "generate_until":
if config.gen_kwargs is not None: if config.gen_kwargs is not None:
task_obj.set_config( task_obj.set_config(
key="generation_kwargs", value=config.gen_kwargs, update=True key="generation_kwargs",
value=config.gen_kwargs,
update=True,
) )
eval_logger.info( eval_logger.info(
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}" f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
...@@ -271,7 +275,7 @@ def simple_evaluate( ...@@ -271,7 +275,7 @@ def simple_evaluate(
) )
else: else:
eval_logger.warning( eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
) )
task_obj.set_config(key="num_fewshot", value=config.num_fewshot) task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
else: else:
...@@ -309,7 +313,9 @@ def simple_evaluate( ...@@ -309,7 +313,9 @@ def simple_evaluate(
limit=config.limit, limit=config.limit,
samples=config.samples, samples=config.samples,
cache_requests=config.cache_requests, cache_requests=config.cache_requests,
rewrite_requests_cache=config.request_caching_args.get("rewrite_requests_cache", False), rewrite_requests_cache=config.request_caching_args.get(
"rewrite_requests_cache", False
),
bootstrap_iters=bootstrap_iters, bootstrap_iters=bootstrap_iters,
write_out=config.write_out, write_out=config.write_out,
log_samples=True if config.predict_only else config.log_samples, log_samples=True if config.predict_only else config.log_samples,
...@@ -325,7 +331,9 @@ def simple_evaluate( ...@@ -325,7 +331,9 @@ def simple_evaluate(
if lm.rank == 0: if lm.rank == 0:
if isinstance(config.model, str): if isinstance(config.model, str):
model_name = config.model model_name = config.model
elif hasattr(config.model, "config") and hasattr(config.model.config, "_name_or_path"): elif hasattr(config.model, "config") and hasattr(
config.model.config, "_name_or_path"
):
model_name = config.model.config._name_or_path model_name = config.model.config._name_or_path
else: else:
model_name = type(config.model).__name__ model_name = type(config.model).__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment