Commit 4d3387f6 authored by Baber

cleanup

parent 79a22a11
@@ -53,6 +53,6 @@ class FilterEnsemble:
             resps = f().apply(resps, docs)
         # add the end results after filtering to filtered_requests of their respective source instances.
-        # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
+        # has a key `self.name`: each FilterEnsemble applied in a given run should use a different name.
         for inst, resp in zip(instances, resps):
             inst.filtered_resps[self.name] = resp
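The comment in this hunk is the key contract: each FilterEnsemble applied in a run must use a distinct name, because every ensemble writes its end result into the same per-instance dict keyed by that name. A minimal sketch of that idea with toy stand-ins (Instance is a simplified dataclass and apply_ensemble a hypothetical helper, not the harness's FilterEnsemble.apply):

from dataclasses import dataclass, field


@dataclass
class Instance:
    # simplified stand-in for the harness Instance: raw responses plus
    # one filtered result per ensemble name
    resps: list
    filtered_resps: dict = field(default_factory=dict)


def apply_ensemble(name, filter_fns, instances):
    # hypothetical helper mirroring the hunk above: run filters in sequence,
    # then store each end result under the ensemble's name
    resps = [inst.resps for inst in instances]
    for f in filter_fns:
        resps = f(resps)
    for inst, resp in zip(instances, resps):
        inst.filtered_resps[name] = resp


insts = [Instance(resps=[" 42 ", "42."]), Instance(resps=["7!"])]
apply_ensemble("strip", [lambda rs: [[r.strip() for r in group] for group in rs]], insts)
apply_ensemble("take_first", [lambda rs: [group[0] for group in rs]], insts)
print(insts[0].filtered_resps)  # {'strip': ['42', '42.'], 'take_first': ' 42 '}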
@@ -220,8 +220,8 @@ class Registry(Generic[T]):
            >>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
        Raises:
-            ValueError: If alias already registered with different target
-            TypeError: If object doesn't inherit from base_cls (when specified)
+            ValueError: If alias is already registered with a different target
+            TypeError: If an object doesn't inherit from base_cls (when specified)
        """
        def _store(alias: str, target: T | Placeholder) -> None:
@@ -229,21 +229,27 @@ class Registry(Generic[T]):
            # collision handling ------------------------------------------
            if current is not None and current != target:
                # allow placeholder → real object upgrade
-                if isinstance(current, str) and isinstance(target, type):
-                    # mod, _, cls = current.partition(":")
-                    if current == f"{target.__module__}:{target.__name__}":
-                        self._objs[alias] = target
-                        return
+                # mod, _, cls = current.partition(":")
+                if (
+                    isinstance(current, str)
+                    and isinstance(target, type)
+                    and current == f"{target.__module__}:{target.__name__}"
+                ):
+                    self._objs[alias] = target
+                    return
                raise ValueError(
                    f"{self._name!r} alias '{alias}' already registered ("
                    f"existing={current}, new={target})"
                )
            # type check for concrete classes ----------------------------------------------
-            if self._base_cls is not None and isinstance(target, type):
-                if not issubclass(target, self._base_cls):  # type: ignore[arg-type]
-                    raise TypeError(
-                        f"{target} must inherit from {self._base_cls} to be a {self._name}"
-                    )
+            if (
+                self._base_cls is not None
+                and isinstance(target, type)
+                and not issubclass(target, self._base_cls)
+            ):
+                raise TypeError(
+                    f"{target} must inherit from {self._base_cls} to be a {self._name}"
+                )
            self._objs[alias] = target
        def decorator(obj: T) -> T:  # type: ignore[valid-type]
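The reshaped condition above is the placeholder-upgrade path: a lazily registered "module:Class" string may be replaced by the real class once it is imported, while any other conflicting re-registration raises. A small sketch of that logic with a toy dict-backed registry (ToyRegistry and MyModel are illustrative names, not the harness's Registry):

class ToyRegistry:
    """Toy version of the placeholder-upgrade logic sketched above."""

    def __init__(self):
        self._objs = {}

    def store(self, alias, target):
        current = self._objs.get(alias)
        if current is not None and current != target:
            # allow "module:Class" placeholder -> real class upgrade
            if (
                isinstance(current, str)
                and isinstance(target, type)
                and current == f"{target.__module__}:{target.__name__}"
            ):
                self._objs[alias] = target
                return
            raise ValueError(f"alias {alias!r} already registered")
        self._objs[alias] = target


class MyModel:
    pass


reg = ToyRegistry()
reg.store("lazy-name", f"{MyModel.__module__}:{MyModel.__name__}")  # placeholder string
reg.store("lazy-name", MyModel)                                     # upgrade, no error
print(reg._objs["lazy-name"])  # <class '...MyModel'>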
@@ -409,9 +415,7 @@ class MetricSpec:
 from lm_eval.api.model import LM  # noqa: E402
-model_registry: Registry[type[LM]] = cast(
-    Registry[type[LM]], Registry("model", base_cls=LM)
-)
+model_registry = cast(Registry[type[LM]], Registry("model", base_cls=LM))
 task_registry: Registry[Callable[..., Any]] = Registry("task")
 metric_registry: Registry[MetricSpec] = Registry("metric")
 metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
@@ -457,7 +461,7 @@ def register_metric(**kw):
    then registers it in the metric registry.
    Args:
-        **kw: Keyword arguments including:
+        **kw: Keyword arguments including
            - metric: Name to register the metric under (required)
            - aggregation: Name of aggregation function in metric_agg_registry
            - higher_is_better: Whether higher scores are better (default: True)
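The kwargs listed in this docstring drive a decorator-style registration. A hedged sketch of how such a decorator can be wired, using plain dicts in place of the real metric and aggregation registries (my_metric_registry, my_agg_registry and exact_match are illustrative only):

# Hypothetical stand-ins for the metric and aggregation registries.
my_metric_registry: dict = {}
my_agg_registry = {"mean": lambda xs: sum(xs) / len(xs)}


def register_metric(**kw):
    """Toy decorator mirroring the documented kwargs: metric, aggregation, higher_is_better."""

    def deco(fn):
        my_metric_registry[kw["metric"]] = {
            "fn": fn,
            "aggregation": my_agg_registry[kw.get("aggregation", "mean")],
            "higher_is_better": kw.get("higher_is_better", True),
        }
        return fn

    return deco


@register_metric(metric="exact_match", aggregation="mean", higher_is_better=True)
def exact_match(pred, gold):
    return float(pred == gold)


spec = my_metric_registry["exact_match"]
print(spec["aggregation"]([spec["fn"]("a", "a"), spec["fn"]("a", "b")]))  # 0.5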
@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False):
        The metric's compute function
    Raises:
-        KeyError: If metric not found in registry or HF evaluate
+        KeyError: If a metric is not found in registry or HF evaluate
    """
    try:
        spec = metric_registry.get(name)
@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False):
            return hf.load(name).compute  # type: ignore[attr-defined]
        except Exception:
-            raise KeyError(f"Metric '{name}' not found anywhere")
+            raise KeyError(f"Metric '{name}' not found anywhere") from None
 register_metric_aggregation = metric_agg_registry.register
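Two details in the hunks above are worth spelling out: lookup falls back from the local registry to HF evaluate, and the final KeyError is raised `from None` so the fallback's own failure is not chained into the traceback. A dependency-free sketch of the `from None` effect (get_metric_like is a hypothetical helper, not the harness's get_metric):

def get_metric_like(name, primary, fallback):
    # look up in the primary mapping first, then in the fallback,
    # and hide the intermediate failure with `from None`
    try:
        return primary[name]
    except KeyError:
        pass
    try:
        return fallback[name]
    except Exception:
        raise KeyError(f"Metric '{name}' not found anywhere") from None


try:
    get_metric_like("bleu", primary={}, fallback={})
except KeyError as e:
    # `from None` sets __suppress_context__, so the original lookup failure
    # is not printed as "During handling of the above exception..."
    print(e.__suppress_context__)  # True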
@@ -4,7 +4,7 @@ import textwrap
 from argparse import Namespace
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 import yaml
@@ -214,7 +214,7 @@ class EvaluatorConfig:
        # Parse string arguments that should be dictionaries
        config = cls._parse_dict_args(config)
-        # Create instance and validate
+        # Create an instance and validate
        instance = cls(**config)
        if used_config:
            print(textwrap.dedent(f"""{instance}"""))
@@ -238,7 +238,7 @@ class EvaluatorConfig:
        return instance
    @staticmethod
-    def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
+    def _parse_dict_args(config: dict[str, Any]) -> dict[str, Any]:
        """Parse string arguments that should be dictionaries."""
        for key in config:
            if key in DICT_KEYS and isinstance(config[key], str):
@@ -246,7 +246,7 @@ class EvaluatorConfig:
        return config
    @staticmethod
-    def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
+    def load_yaml_config(config_path: Union[str, Path]) -> dict[str, Any]:
        """Load and validate YAML config file."""
        config_file = (
            Path(config_path) if not isinstance(config_path, Path) else config_path
@@ -257,9 +257,9 @@ class EvaluatorConfig:
        try:
            yaml_data = yaml.safe_load(config_file.read_text())
        except yaml.YAMLError as e:
-            raise ValueError(f"Invalid YAML in {config_path}: {e}")
+            raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
-            raise ValueError(f"Could not read config file {config_path}: {e}")
+            raise ValueError(f"Could not read config file {config_path}: {e}") from e
        if not isinstance(yaml_data, dict):
            raise ValueError(
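The `from e` additions keep the original parser or I/O error attached as `__cause__` of the friendlier ValueError. A minimal sketch of the same pattern, assuming only PyYAML (load_yaml_config_like is a toy stand-in, not the method above):

from pathlib import Path
import tempfile

import yaml


def load_yaml_config_like(config_path):
    # toy version of the pattern above: wrap parser/IO errors in a
    # ValueError while keeping the original exception as __cause__
    config_file = Path(config_path)
    try:
        data = yaml.safe_load(config_file.read_text())
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
    except (OSError, UnicodeDecodeError) as e:
        raise ValueError(f"Could not read config file {config_path}: {e}") from e
    if not isinstance(data, dict):
        raise ValueError(f"{config_path} must contain a YAML mapping")
    return data


with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("tasks: [hellaswag\n")  # deliberately invalid: unclosed flow sequence

try:
    load_yaml_config_like(f.name)
except ValueError as e:
    print(type(e.__cause__).__name__)  # the original YAML parser error, e.g. ParserError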
@@ -307,7 +307,7 @@ class EvaluatorConfig:
            raise ValueError("Need to specify task to evaluate.")
    def _process_arguments(self) -> None:
-        """Process samples argument - load from file if needed."""
+        """Process samples argument - load from a file if needed."""
        if self.samples:
            if isinstance(self.samples, dict):
                self.samples = self.samples
@@ -328,7 +328,6 @@ class EvaluatorConfig:
    def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
        """Process and validate tasks, return resolved task names."""
-        from lm_eval import utils
        from lm_eval.tasks import TaskManager
        # if metadata manually passed use that:
@@ -365,7 +364,7 @@ class EvaluatorConfig:
        return task_manager
    def _set_trust_remote_code(self) -> None:
-        """Apply trust_remote_code setting if enabled."""
+        """Apply the trust_remote_code setting if enabled."""
        if self.trust_remote_code:
            # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
            # because it's already been determined based on the prior env var before launching our
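The HACK comment refers to flipping the datasets library's cached trust setting after it has already been read from the environment at import time. A hedged sketch of that pattern; it assumes datasets.config.HF_DATASETS_TRUST_REMOTE_CODE is the module-level flag involved, which holds for recent datasets releases but should be checked against the installed version:

import os


def set_trust_remote_code():
    # re-export the env var for any subprocesses...
    os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "true"

    # ...and override the already-imported copy of the setting, since
    # datasets.config read the env var once at import time
    import datasets

    datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True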
@@ -21,7 +21,7 @@ def serialize_callable(
        return value
    else:
        try:
-            return getsource(value)
+            return getsource(value)  # type: ignore
        except (TypeError, OSError):
            return str(value)
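getsource only works for callables whose source file is available; builtins, C extensions and REPL-defined functions raise TypeError or OSError, which is why the fallback to str(value) exists. A standalone sketch of that behaviour (describe is an illustrative helper, not the serialize_callable function above):

from inspect import getsource


def describe(value):
    # mirror the fallback above: return source text when available, repr otherwise
    if isinstance(value, str):
        return value
    try:
        return getsource(value)
    except (TypeError, OSError):
        return str(value)


def my_fn(x):
    return x + 1


print(describe(my_fn).splitlines()[0])  # "def my_fn(x):" when run from a file
print(describe(len))                    # "<built-in function len>" -- getsource raises TypeError for builtins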
@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any:
            module_parts = relative_path.replace(".py", "").replace("/", ".")
            module_name = f"lm_eval.tasks.{module_parts}"
        else:
-            # Fallback to full path if pattern not found
+            # Fallback to a full path if a pattern not found
            module_name = str(module_path.with_suffix(""))
    else:
-        # External module - use full path without extension
+        # External module - use a full path without extension
        module_name = str(module_path.with_suffix(""))
    # Check if we need to reload the module
@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any:
        raise ImportError(f"Cannot load module from {module_path}") from None
    module = importlib.util.module_from_spec(spec)
    # Store mtime for future checks
-    module.__mtime__ = module_path.stat().st_mtime_ns
+    module.__mtime__ = module_path.stat().st_mtime_ns  # type: ignore
    spec.loader.exec_module(module)  # type: ignore[arg-type]
    sys.modules[module_name] = module
    return module
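These two hunks belong to an mtime-guarded module cache: the loaded module is stamped with its file's st_mtime_ns so a later call can decide whether the cached copy is still fresh. A hedged sketch of that loader using the same importlib calls (load_with_mtime_cache is a hypothetical helper, not the harness's _load_module_with_cache):

import importlib.util
import sys
from pathlib import Path


def load_with_mtime_cache(module_path: Path, module_name: str):
    # reuse the cached module only if the file has not changed since it was loaded
    cached = sys.modules.get(module_name)
    current_mtime = module_path.stat().st_mtime_ns
    if cached is not None and getattr(cached, "__mtime__", None) == current_mtime:
        return cached

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {module_path}") from None
    module = importlib.util.module_from_spec(spec)
    module.__mtime__ = current_mtime  # stamp for the next cache check
    spec.loader.exec_module(module)
    sys.modules[module_name] = module
    return module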
@@ -32,9 +32,9 @@ class TaskFactory:
        registry: Mapping[str, Entry],
    ):
        """
        • entry.kind == TASK / PY_TASK ➜ returns instantiated task object
        • entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
        • entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
        """
        if entry.kind is Kind.TAG:
            return self._build_tag(entry, overrides, registry)
@@ -121,4 +121,4 @@ class TaskFactory:
    def _ctor_accepts_config(cls) -> bool:
        init = getattr(cls, "__init__", None)
-        return init and "config" in inspect.signature(init).parameters
+        return bool(init and "config" in inspect.signature(init).parameters)
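The bool() wrapper matters because `init and ...` returns the left operand itself when it is falsy (None), not False, which would break the `-> bool` annotation. A tiny illustration:

import inspect


def f(config=None):
    pass


init = None  # simulate getattr(cls, "__init__", None) coming back empty
print(init and "config" in inspect.signature(f).parameters)        # None -- `and` returns the falsy operand itself
print(bool(init and "config" in inspect.signature(f).parameters))  # False -- normalized to a real bool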
@@ -61,8 +61,8 @@ class TaskManager:
    def load_spec(self, spec: str | dict[str, Any]):
        """Spec can be:
        • str task / group / tag name (registered)
        • dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
        """
        if isinstance(spec, str):
            entry = self._entry(spec)
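load_spec accepts either a registered name or an inline dict of overrides. A hedged sketch of that dispatch shape with a toy registry (load_spec_like and TOY_REGISTRY are illustrative, not the TaskManager implementation):

from typing import Any


TOY_REGISTRY = {"hellaswag": {"task": "hellaswag", "num_fewshot": 0}}


def load_spec_like(spec: str | dict[str, Any]) -> dict[str, Any]:
    # str  -> look up a registered task/group/tag by name
    # dict -> treat it as inline overrides on top of the named task
    if isinstance(spec, str):
        return dict(TOY_REGISTRY[spec])
    base = dict(TOY_REGISTRY[spec["task"]])
    base.update(spec)
    return base


print(load_spec_like("hellaswag"))                                # {'task': 'hellaswag', 'num_fewshot': 0}
print(load_spec_like({"task": "hellaswag", "num_fewshot": 5}))    # {'task': 'hellaswag', 'num_fewshot': 5}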