Commit 93b2ab37 authored by Baber

refactor registry

parent de496b80
...@@ -4,8 +4,8 @@ import os
import random
import re
import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
from collections.abc import Iterable, Sequence
from typing import Callable, List, Optional, TypeVar
import numpy as np
import sacrebleu
...
-import logging
-from typing import Callable, Dict, Union
from __future__ import annotations
import importlib
import inspect
import threading
from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import (
Any,
Callable,
Generic,
TypeVar,
)
try: # Python ≥ 3.10: importlib.metadata supports entry_points(group=...)
import importlib.metadata as md
except ImportError: # pragma: no cover - fallback for 3.8/3.9 runtimes
import importlib_metadata as md # type: ignore
__all__ = [
"Registry",
"MetricSpec",
# concrete registries
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
# helper
"freeze_all",
# Legacy compatibility
"DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY",
"register_model",
"get_model",
"register_task",
"get_task",
"register_metric",
"get_metric",
"register_metric_aggregation",
"get_metric_aggregation",
"register_higher_is_better",
"is_higher_better",
"register_filter",
"get_filter",
"register_aggregation",
"get_aggregation",
"MODEL_REGISTRY",
"TASK_REGISTRY",
"METRIC_REGISTRY",
"METRIC_AGGREGATION_REGISTRY",
"HIGHER_IS_BETTER_REGISTRY",
"FILTER_REGISTRY",
]
T = TypeVar("T")
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
class Registry(Generic[T]):
"""Name -> object mapping with decorator helpers and **lazy import** support."""
#: The underlying mutable mapping (might turn into MappingProxy on freeze)
_objects: MutableMapping[str, T | str | md.EntryPoint]
def __init__(
self,
name: str,
*,
base_cls: type[T] | None = None,
store: MutableMapping[str, T | str | md.EntryPoint] | None = None,
validator: Callable[[T], bool] | None = None,
) -> None:
self._name: str = name
self._base_cls: type[T] | None = base_cls
self._objects = store if store is not None else {}
self._metadata: dict[
str, dict[str, Any]
] = {} # Store metadata for each registered item
self._validator = validator # Custom validation function
self._lock = threading.RLock()
# ------------------------------------------------------------------
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
def register(
self,
*aliases: str,
lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def _do_register(target: T | str | md.EntryPoint) -> None:
if not aliases:
_aliases = (getattr(target, "__name__", str(target)),)
else:
_aliases = aliases
with self._lock:
for alias in _aliases:
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
# Store metadata if provided
if metadata:
self._metadata[alias] = metadata
# ─── decorator path ───
def decorator(obj: T) -> T: # type: ignore[valid-type]
_do_register(obj)
return obj
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
_do_register(lazy)
return lambda x: x # no‑op decorator for accidental use
return decorator
def register_bulk(
self,
items: dict[str, T | str | md.EntryPoint],
metadata: dict[str, dict[str, Any]] | None = None,
) -> None:
"""Register multiple items at once.
Args:
items: Dictionary mapping aliases to objects/lazy paths
metadata: Optional dictionary mapping aliases to metadata
"""
with self._lock:
for alias, target in items.items():
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
# Store metadata if provided
if metadata and alias in metadata:
self._metadata[alias] = metadata[alias]
# ------------------------------------------------------------------
# Lookup & materialisation
# ------------------------------------------------------------------
@lru_cache(maxsize=256) # Bounded cache to prevent memory growth
def _materialise(self, target: T | str | md.EntryPoint) -> T:
"""Import *target* if it is a dotted‑path string or EntryPoint."""
if isinstance(target, str):
mod, _, obj_name = target.partition(":")
if not _:
raise ValueError(
f"Lazy path '{target}' must be in 'module:object' form"
)
module = importlib.import_module(mod)
return getattr(module, obj_name)
if isinstance(target, md.EntryPoint):
return target.load()
return target # concrete already
def get(self, alias: str) -> T:
with self._lock:
try:
target = self._objects[alias]
except KeyError as exc:
raise KeyError(
f"Unknown {self._name} '{alias}'. Available: "
f"{', '.join(self._objects)}"
) from exc
# Only materialize if it's a string or EntryPoint (lazy placeholder)
if isinstance(target, (str, md.EntryPoint)):
concrete: T = self._materialise(target)
# First‑touch: swap placeholder with concrete obj for future calls
if concrete is not target:
self._objects[alias] = concrete
else:
# Already materialized, just return it
concrete = target
# Late type check (for placeholders)
if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{concrete} does not inherit from {self._base_cls} "
f"(registered under alias '{alias}')"
)
-import evaluate as hf_evaluate
# Custom validation
if self._validator is not None and not self._validator(concrete):
raise ValueError(
f"{concrete} failed custom validation for {self._name} registry "
f"(registered under alias '{alias}')"
)
-from lm_eval.api.model import LM
return concrete
# Mapping / dunder helpers -------------------------------------------------
-eval_logger = logging.getLogger(__name__)
def __getitem__(self, alias: str) -> T:  # noqa
return self.get(alias)
-MODEL_REGISTRY = {}
def __iter__(self):  # noqa
return iter(self._objects)
def __len__(self) -> int: # noqa
return len(self._objects)
-def register_model(*names):
-# either pass a list or a single alias.
-# function receives them as a tuple of strings
-def decorate(cls):
-for name in names:
-assert issubclass(cls, LM), (
-f"Model '{name}' ({cls.__name__}) must extend LM class"
-)
-assert name not in MODEL_REGISTRY, (
-f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
-)
def items(self):  # noqa
return self._objects.items()
# Introspection -----------------------------------------------------------
def origin(self, alias: str) -> str | None:
obj = self._objects.get(alias)
try:
if isinstance(obj, str) or isinstance(obj, md.EntryPoint):
return None # placeholder - unknown until imported
file = inspect.getfile(obj) # type: ignore[arg-type]
line = inspect.getsourcelines(obj)[1] # type: ignore[arg-type]
return f"{file}:{line}"
except (
TypeError,
OSError,
AttributeError,
): # pragma: no cover - best-effort only
# TypeError: object not suitable for inspect
# OSError: file not found or accessible
# AttributeError: object lacks expected attributes
return None
-MODEL_REGISTRY[name] = cls
-return cls
def get_metadata(self, alias: str) -> dict[str, Any] | None:
"""Get metadata for a registered item."""
with self._lock:
return self._metadata.get(alias)
-return decorate
# Mutability --------------------------------------------------------------
def freeze(self):
"""Make the registry *names* immutable (materialisation still works)."""
with self._lock:
if isinstance(self._objects, MappingProxyType):
return # already frozen
self._objects = MappingProxyType(dict(self._objects)) # type: ignore[assignment]
-def get_model(model_name):
-try:
-return MODEL_REGISTRY[model_name]
-except KeyError:
-raise ValueError(
-f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
-)
-TASK_REGISTRY = {}
-GROUP_REGISTRY = {}
-ALL_TASKS = set()
-func2task_index = {}
def clear(self):
"""Clear the registry (useful for tests). Cannot be called on frozen registries."""
with self._lock:
if isinstance(self._objects, MappingProxyType):
raise RuntimeError("Cannot clear a frozen registry")
self._objects.clear()
self._metadata.clear()
self._materialise.cache_clear()  # type: ignore[attr-defined]  # Added by lru_cache
# ────────────────────────────────────────────────────────────────────────
# Structured objects stored in registries
# ────────────────────────────────────────────────────────────────────────
-def register_task(name):
-def decorate(fn):
-assert name not in TASK_REGISTRY, (
-f"task named '{name}' conflicts with existing registered task!"
-)
-TASK_REGISTRY[name] = fn
-ALL_TASKS.add(name)
-func2task_index[fn.__name__] = name
-return fn
-return decorate
@dataclass(frozen=True)
class MetricSpec:
"""Bundle compute fn, aggregator, and *higher‑is‑better* flag."""
compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], Mapping[str, float]]
higher_is_better: bool = True
output_type: str | None = None  # e.g., "probability", "string", "numeric"
requires: list[str] | None = None  # Dependencies on other metrics/data
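For orientation, here is a minimal sketch (not part of the diff) of how a MetricSpec could be built and registered through the generic Registry; the "always_one" alias and the mean-style aggregator below are invented for illustration.

# Hypothetical example - "always_one" and these helpers are illustrative only.
def _always_one(references, predictions):
    return 1.0

def _mean_of(values):
    vals = list(values)
    return {"always_one": sum(vals) / len(vals) if vals else 0.0}

spec = MetricSpec(compute=_always_one, aggregate=_mean_of, higher_is_better=True)
metric_registry.register("always_one")(spec)      # decorator applied directly to the spec
assert metric_registry.get("always_one") is spec  # no lazy materialisation needed here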
# ────────────────────────────────────────────────────────────────────────
# Concrete registries used by lm_eval
# ────────────────────────────────────────────────────────────────────────
-def register_group(name):
-def decorate(fn):
-func_name = func2task_index[fn.__name__]
-if name in GROUP_REGISTRY:
-GROUP_REGISTRY[name].append(func_name)
-else:
-GROUP_REGISTRY[name] = [func_name]
-ALL_TASKS.add(name)
-return fn
-return decorate
-OUTPUT_TYPE_REGISTRY = {}
-METRIC_REGISTRY = {}
-METRIC_AGGREGATION_REGISTRY = {}
-AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
-HIGHER_IS_BETTER_REGISTRY = {}
-FILTER_REGISTRY = {}
from lm_eval.api.model import LM  # noqa: E402
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = (
Registry("metric aggregation")
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[Callable] = Registry("filter")
# Default metric registry for output types
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
...@@ -90,107 +347,194 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until": ["exact_match"],
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY: dict[str, Callable] = {}
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model = model_registry.register
get_model = model_registry.get
register_task = task_registry.register
get_task = task_registry.get
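Since register_model and get_model are now bound methods of model_registry, both the decorator style and the lazy "module:Class" style stay available. A hedged sketch follows; the aliases, MyLM, and the dotted path are invented, and real models are expected to subclass lm_eval.api.model.LM.

# Hypothetical example - "my-model", "my-lazy-model", MyLM and my_pkg are invented.
from lm_eval.api.model import LM

@register_model("my-model")
class MyLM(LM):
    ...  # abstract LM methods omitted in this sketch

# Lazy form: my_pkg.models is only imported on the first get_model("my-lazy-model") call.
register_model("my-lazy-model", lazy="my_pkg.models:MyLazyLM")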
# Special handling for metric registration which uses different API
def register_metric(**kwargs):
"""Register a metric with metadata.
Compatible with old registry API that used keyword arguments.
"""
-def register_metric(**args):
-# TODO: do we want to enforce a certain interface to registered metrics?
-def decorate(fn):
-assert "metric" in args
-name = args["metric"]
-for key, registry in [
-("metric", METRIC_REGISTRY),
-("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
-("aggregation", METRIC_AGGREGATION_REGISTRY),
-]:
-if key in args:
-value = args[key]
-assert value not in registry, (
-f"{key} named '{value}' conflicts with existing registered {key}!"
def decorate(fn):
metric_name = kwargs.get("metric")
if not metric_name:
raise ValueError("metric name is required")
# Create MetricSpec with the function and metadata
spec = MetricSpec(
compute=fn,
aggregate=lambda x: {},  # Default aggregation returns empty dict
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
# Register in metric registry
metric_registry._objects[metric_name] = spec
# Also handle aggregation if specified
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
# Try to get aggregation from AGGREGATION_REGISTRY
if agg_name in AGGREGATION_REGISTRY:
spec = MetricSpec(
compute=fn,
aggregate=AGGREGATION_REGISTRY[agg_name],
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
metric_registry._objects[metric_name] = spec
-)
-if key == "metric":
-registry[name] = fn
-elif key == "aggregation":
-registry[name] = AGGREGATION_REGISTRY[value]
-else:
-registry[name] = value
# Handle higher_is_better registry
if "higher_is_better" in kwargs:
higher_is_better_registry._objects[metric_name] = kwargs["higher_is_better"]
return fn
return decorate
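The keyword-based register_metric decorator therefore keeps working for existing metric modules. A sketch of the legacy call pattern follows; the "toy_accuracy" metric is invented, and the aggregation lookup only takes effect if "mean" is already present in AGGREGATION_REGISTRY.

# Hypothetical example - "toy_accuracy" is not a real lm-eval metric.
@register_metric(
    metric="toy_accuracy",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",  # resolved via AGGREGATION_REGISTRY when that name is registered
)
def toy_accuracy(references, predictions):
    return float(references == predictions)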
-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
-if name in METRIC_REGISTRY:
-return METRIC_REGISTRY[name]
-else:
-eval_logger.warning(
def get_metric(name: str, hf_evaluate_metric=False):
"""Get a metric by name, with fallback to HF evaluate."""
if not hf_evaluate_metric:
try:
spec = metric_registry.get(name)
if isinstance(spec, MetricSpec):
return spec.compute
return spec
except KeyError:
import logging
logging.getLogger(__name__).warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
)
# Fallback to HF evaluate
try:
import evaluate as hf_evaluate
metric_object = hf_evaluate.load(name)
return metric_object.compute
except Exception:
-eval_logger.error(
import logging
logging.getLogger(__name__).error(
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
)
return None
-def register_aggregation(name: str):
-def decorate(fn):
-assert name not in AGGREGATION_REGISTRY, (
-f"aggregation named '{name}' conflicts with existing registered aggregation!"
-)
-AGGREGATION_REGISTRY[name] = fn
-return fn
-return decorate
register_metric_aggregation = metric_agg_registry.register
def get_metric_aggregation(metric_name: str):
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
if metric_name in metric_registry._objects:
metric_spec = metric_registry._objects[metric_name]
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate
# Fall back to metric_agg_registry (for standalone aggregations)
if metric_name in metric_agg_registry._objects:
return metric_agg_registry._objects[metric_name]
# If not found, raise error
raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(AGGREGATION_REGISTRY.keys())}"
)
-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
-try:
-return AGGREGATION_REGISTRY[name]
-except KeyError:
-eval_logger.warning(f"{name} not a registered aggregation metric!")
-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
-try:
-return METRIC_AGGREGATION_REGISTRY[name]
-except KeyError:
-eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get
register_filter = filter_registry.register
get_filter = filter_registry.get
-def is_higher_better(metric_name) -> bool:
-try:
-return HIGHER_IS_BETTER_REGISTRY[metric_name]
-except KeyError:
-eval_logger.warning(
-f"higher_is_better not specified for metric '{metric_name}'!"
-)
# Special handling for AGGREGATION_REGISTRY which works differently
-def register_filter(name):
-def decorate(cls):
-if name in FILTER_REGISTRY:
-eval_logger.info(
-f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
-)
-FILTER_REGISTRY[name] = cls
-return cls
-return decorate
-def get_filter(filter_name: Union[str, Callable]) -> Callable:
-try:
-return FILTER_REGISTRY[filter_name]
-except KeyError as e:
-if callable(filter_name):
-return filter_name
-else:
-eval_logger.warning(f"filter `{filter_name}` is not registered!")
-raise e
def register_aggregation(name: str):
def decorate(fn):
if name in AGGREGATION_REGISTRY:
raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
import logging
logging.getLogger(__name__).warning(
f"{name} not a registered aggregation metric!"
)
return None
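register_aggregation still populates the plain AGGREGATION_REGISTRY dict, which is also what the aggregation= keyword of register_metric resolves against. A small sketch; the "toy_mean" name is invented.

# Hypothetical example - "toy_mean" is not one of the built-in aggregations.
@register_aggregation("toy_mean")
def toy_mean(items):
    items = list(items)
    return sum(items) / len(items) if items else 0.0

assert get_aggregation("toy_mean") is toy_mean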
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry‑point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
# "lm_eval.models": model_registry,
# "lm_eval.tasks": task_registry,
# "lm_eval.metrics": metric_registry,
# }.items():
# for _ep in md.entry_points(group=_group):
# _reg.register(_ep.name, lazy=_ep)
# ────────────────────────────────────────────────────────────────────────
# Convenience
# ────────────────────────────────────────────────────────────────────────
def freeze_all() -> None: # pragma: no cover
"""Freeze every global registry (idempotent)."""
for _reg in (
model_registry,
task_registry,
metric_registry,
metric_agg_registry,
higher_is_better_registry,
filter_registry,
):
_reg.freeze()
# ────────────────────────────────────────────────────────────────────────
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY: Mapping[str, type[LM]] = MappingProxyType(model_registry._objects) # type: ignore[attr-defined]
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = MappingProxyType(
task_registry._objects
) # type: ignore[attr-defined]
METRIC_REGISTRY: Mapping[str, MetricSpec] = MappingProxyType(metric_registry._objects) # type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = MappingProxyType(
metric_agg_registry._objects
) # type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = MappingProxyType(
higher_is_better_registry._objects
) # type: ignore[attr-defined]
FILTER_REGISTRY: Mapping[str, Callable] = MappingProxyType(filter_registry._objects) # type: ignore[attr-defined]
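Pulling the pieces together: lazy placeholders are swapped for the real object on first get(), freeze_all() locks the name tables, and the legacy *_REGISTRY globals are read-only MappingProxyType views. A hedged end-to-end sketch, with invented names:

# Hypothetical usage - the "demo-lm" alias and my_pkg.demo:DemoLM path are invented,
# and DemoLM is assumed to subclass lm_eval.api.model.LM.
from lm_eval.api.registry import MODEL_REGISTRY, freeze_all, get_model, model_registry

model_registry.register("demo-lm", lazy="my_pkg.demo:DemoLM")  # nothing imported yet
DemoLM = get_model("demo-lm")  # first lookup imports my_pkg.demo and caches the class

freeze_all()  # registry names become immutable; calling it twice is a no-op
try:
    MODEL_REGISTRY["demo-lm"] = object()  # the legacy global is a read-only view
except TypeError:
    pass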
...@@ -3,18 +3,15 @@ import ast
import logging
import random
import re
-from collections.abc import Callable
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
Any,
Dict,
-Iterable,
-Iterator,
List,
Literal,
-Mapping,
Optional,
Tuple,
Union,
...@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
-arguments=(ctx, " {}".format(choice)),
arguments=(ctx, f" {choice}"),
idx=i,
**kwargs,
)
...
-from . import (
-anthropic_llms,
-api_models,
-dummy,
-gguf,
-hf_audiolm,
-hf_steered,
-hf_vlms,
-huggingface,
-ibm_watsonx_ai,
-mamba_lm,
-nemo_lm,
-neuron_optimum,
-openai_completions,
-optimum_ipex,
-optimum_lm,
-sglang_causallms,
-sglang_generate_API,
-textsynth,
-vllm_causallms,
-vllm_vlms,
-)
-# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once

# Define model mappings for lazy registration
MODEL_MAPPING = {
"anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
"anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
"anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
"local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
"local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
"openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
"openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
"dummy": "lm_eval.models.dummy:DummyLM",
"gguf": "lm_eval.models.gguf:GGUFLM",
"ggml": "lm_eval.models.gguf:GGUFLM",
"hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
"steered": "lm_eval.models.hf_steered:SteeredHF",
"hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
"hf-auto": "lm_eval.models.huggingface:HFLM",
"hf": "lm_eval.models.huggingface:HFLM",
"huggingface": "lm_eval.models.huggingface:HFLM",
"watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
"mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
"nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
"neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
"ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
"openvino": "lm_eval.models.optimum_lm:OptimumLM",
"sglang": "lm_eval.models.sglang_causallms:SGLANG",
"sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
"textsynth": "lm_eval.models.textsynth:TextSynthLM",
"vllm": "lm_eval.models.vllm_causallms:VLLM",
"vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}

# Register all models lazily
def _register_all_models():
"""Register all known models lazily in the registry."""
from lm_eval.api.registry import model_registry
for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry:
# Call register with the lazy parameter, returns a decorator
model_registry.register(name, lazy=path)(None)

# Call registration on module import
_register_all_models()

__all__ = ["MODEL_MAPPING"]
try:
...
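With the mapping above, importing lm_eval.models no longer imports every backend; a model module is loaded only when its alias is first resolved. A short sketch using the built-in "dummy" alias, which needs no optional dependencies:

import lm_eval.models  # registers every alias in MODEL_MAPPING as a lazy placeholder
from lm_eval.api.registry import get_model

DummyLM = get_model("dummy")          # first lookup imports lm_eval.models.dummy:DummyLM
assert get_model("dummy") is DummyLM  # the placeholder has been swapped for the class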
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
-from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from peft.peft_model import PeftModel
...
...@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
-from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
...@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
-@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environmental variables.
...
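functools.cache (added in Python 3.9) is simply lru_cache(maxsize=None) under a shorter name, so the memoised credential lookup behaves exactly as before. A two-line illustration with a hypothetical stand-in function:

from functools import cache

@cache  # equivalent to @lru_cache(maxsize=None): memoises with no eviction
def _load_credentials_once() -> dict:  # hypothetical stand-in for get_watsonx_credentials
    return {"apikey": "read-from-env"}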
...@@ -40,7 +40,7 @@ try:
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
if TYPE_CHECKING:
pass
...
...@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
-class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
...@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
-d = DomainParser()(open(self.domain_file, "r").read().lower())
-p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
...
...@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
-class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
...@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
-d = DomainParser()(open(self.domain_file, "r").read().lower())
-p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
...
...@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from tqdm import tqdm
-# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registryv2 import ALL_TASKS
eval_logger = logging.getLogger(__name__)
...