Commit 93b2ab37 authored by Baber

refactor registry

parent de496b80
@@ -4,8 +4,8 @@ import os
import random
import re
import string
from collections.abc import Iterable, Sequence
from typing import Callable, List, Optional, TypeVar

import numpy as np
import sacrebleu
...
from __future__ import annotations

import importlib
import inspect
import threading
from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import (
    Any,
    Callable,
    Generic,
    TypeVar,
)


try:  # Python >= 3.10
    import importlib.metadata as md
except ImportError:  # pragma: no cover - fallback for 3.8/3.9 runtimes
    import importlib_metadata as md  # type: ignore


__all__ = [
    "Registry",
    "MetricSpec",
    # concrete registries
    "model_registry",
    "task_registry",
    "metric_registry",
    "metric_agg_registry",
    "higher_is_better_registry",
    "filter_registry",
    # helper
    "freeze_all",
    # Legacy compatibility
    "DEFAULT_METRIC_REGISTRY",
    "AGGREGATION_REGISTRY",
    "register_model",
    "get_model",
    "register_task",
    "get_task",
    "register_metric",
    "get_metric",
    "register_metric_aggregation",
    "get_metric_aggregation",
    "register_higher_is_better",
    "is_higher_better",
    "register_filter",
    "get_filter",
    "register_aggregation",
    "get_aggregation",
    "MODEL_REGISTRY",
    "TASK_REGISTRY",
    "METRIC_REGISTRY",
    "METRIC_AGGREGATION_REGISTRY",
    "HIGHER_IS_BETTER_REGISTRY",
    "FILTER_REGISTRY",
]

T = TypeVar("T")


# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
class Registry(Generic[T]):
    """Name -> object mapping with decorator helpers and **lazy import** support."""

    #: The underlying mutable mapping (might turn into a MappingProxy on freeze)
    _objects: MutableMapping[str, T | str | md.EntryPoint]

    def __init__(
        self,
        name: str,
        *,
        base_cls: type[T] | None = None,
        store: MutableMapping[str, T | str | md.EntryPoint] | None = None,
        validator: Callable[[T], bool] | None = None,
    ) -> None:
        self._name: str = name
        self._base_cls: type[T] | None = base_cls
        self._objects = store if store is not None else {}
        self._metadata: dict[
            str, dict[str, Any]
        ] = {}  # Store metadata for each registered item
        self._validator = validator  # Custom validation function
        self._lock = threading.RLock()

    # ------------------------------------------------------------------
    # Registration helpers (decorator or direct call)
    # ------------------------------------------------------------------
    def register(
        self,
        *aliases: str,
        lazy: str | md.EntryPoint | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Callable[[T], T]:
        """``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.

        * If called as a **decorator**, supply an object and *no* ``lazy``.
        * If called as a **plain function** and you want lazy import, leave the
          object out and pass ``lazy=``.
        """

        def _do_register(target: T | str | md.EntryPoint) -> None:
            if not aliases:
                _aliases = (getattr(target, "__name__", str(target)),)
            else:
                _aliases = aliases
            with self._lock:
                for alias in _aliases:
                    if alias in self._objects:
                        # If it's a lazy placeholder being replaced by the concrete object, allow it
                        existing = self._objects[alias]
                        if isinstance(existing, (str, md.EntryPoint)) and isinstance(
                            target, type
                        ):
                            # Allow replacing lazy placeholder with concrete class
                            pass
                        else:
                            raise ValueError(
                                f"{self._name!r} '{alias}' already registered "
                                f"({self._objects[alias]})"
                            )
                    # Eager type check only when we have a concrete class
                    if self._base_cls is not None and isinstance(target, type):
                        if not issubclass(target, self._base_cls):  # type: ignore[arg-type]
                            raise TypeError(
                                f"{target} must inherit from {self._base_cls} "
                                f"to be registered as a {self._name}"
                            )
                    self._objects[alias] = target
                    # Store metadata if provided
                    if metadata:
                        self._metadata[alias] = metadata

        # ─── decorator path ───
        def decorator(obj: T) -> T:  # type: ignore[valid-type]
            _do_register(obj)
            return obj

        # ─── direct-call path with lazy placeholder ───
        if lazy is not None:
            _do_register(lazy)
            return lambda x: x  # no-op decorator for accidental use

        return decorator
    def register_bulk(
        self,
        items: dict[str, T | str | md.EntryPoint],
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Register multiple items at once.

        Args:
            items: Dictionary mapping aliases to objects/lazy paths
            metadata: Optional dictionary mapping aliases to metadata
        """
        with self._lock:
            for alias, target in items.items():
                if alias in self._objects:
                    # If it's a lazy placeholder being replaced by the concrete object, allow it
                    existing = self._objects[alias]
                    if isinstance(existing, (str, md.EntryPoint)) and isinstance(
                        target, type
                    ):
                        # Allow replacing lazy placeholder with concrete class
                        pass
                    else:
                        raise ValueError(
                            f"{self._name!r} '{alias}' already registered "
                            f"({self._objects[alias]})"
                        )
                # Eager type check only when we have a concrete class
                if self._base_cls is not None and isinstance(target, type):
                    if not issubclass(target, self._base_cls):  # type: ignore[arg-type]
                        raise TypeError(
                            f"{target} must inherit from {self._base_cls} "
                            f"to be registered as a {self._name}"
                        )
                self._objects[alias] = target
                # Store metadata if provided
                if metadata and alias in metadata:
                    self._metadata[alias] = metadata[alias]

    # ------------------------------------------------------------------
    # Lookup & materialisation
    # ------------------------------------------------------------------
    @lru_cache(maxsize=256)  # Bounded cache to prevent memory growth
    def _materialise(self, target: T | str | md.EntryPoint) -> T:
        """Import *target* if it is a dotted-path string or EntryPoint."""
        if isinstance(target, str):
            mod, _, obj_name = target.partition(":")
            if not _:
                raise ValueError(
                    f"Lazy path '{target}' must be in 'module:object' form"
                )
            module = importlib.import_module(mod)
            return getattr(module, obj_name)
        if isinstance(target, md.EntryPoint):
            return target.load()
        return target  # concrete already
    def get(self, alias: str) -> T:
        with self._lock:
            try:
                target = self._objects[alias]
            except KeyError as exc:
                raise KeyError(
                    f"Unknown {self._name} '{alias}'. Available: "
                    f"{', '.join(self._objects)}"
                ) from exc
            # Only materialize if it's a string or EntryPoint (lazy placeholder)
            if isinstance(target, (str, md.EntryPoint)):
                concrete: T = self._materialise(target)
                # First-touch: swap placeholder with concrete obj for future calls
                if concrete is not target:
                    self._objects[alias] = concrete
            else:
                # Already materialized, just return it
                concrete = target
            # Late type check (for placeholders)
            if self._base_cls is not None and not issubclass(concrete, self._base_cls):  # type: ignore[arg-type]
                raise TypeError(
                    f"{concrete} does not inherit from {self._base_cls} "
                    f"(registered under alias '{alias}')"
                )
            # Custom validation
            if self._validator is not None and not self._validator(concrete):
                raise ValueError(
                    f"{concrete} failed custom validation for {self._name} registry "
                    f"(registered under alias '{alias}')"
                )
            return concrete

    # Mapping / dunder helpers -------------------------------------------------
    def __getitem__(self, alias: str) -> T:  # noqa
        return self.get(alias)

    def __iter__(self):  # noqa
        return iter(self._objects)

    def __len__(self) -> int:  # noqa
        return len(self._objects)

    def items(self):  # noqa
        return self._objects.items()

    # Introspection -----------------------------------------------------------
    def origin(self, alias: str) -> str | None:
        obj = self._objects.get(alias)
        try:
            if isinstance(obj, str) or isinstance(obj, md.EntryPoint):
                return None  # placeholder - unknown until imported
            file = inspect.getfile(obj)  # type: ignore[arg-type]
            line = inspect.getsourcelines(obj)[1]  # type: ignore[arg-type]
            return f"{file}:{line}"
        except (
            TypeError,
            OSError,
            AttributeError,
        ):  # pragma: no cover - best-effort only
            # TypeError: object not suitable for inspect
            # OSError: file not found or accessible
            # AttributeError: object lacks expected attributes
            return None

    def get_metadata(self, alias: str) -> dict[str, Any] | None:
        """Get metadata for a registered item."""
        with self._lock:
            return self._metadata.get(alias)

    # Mutability --------------------------------------------------------------
    def freeze(self):
        """Make the registry *names* immutable (materialisation still works)."""
        with self._lock:
            if isinstance(self._objects, MappingProxyType):
                return  # already frozen
            self._objects = MappingProxyType(dict(self._objects))  # type: ignore[assignment]

    def clear(self):
        """Clear the registry (useful for tests). Cannot be called on frozen registries."""
        with self._lock:
            if isinstance(self._objects, MappingProxyType):
                raise RuntimeError("Cannot clear a frozen registry")
            self._objects.clear()
            self._metadata.clear()
            self._materialise.cache_clear()  # type: ignore[attr-defined]  # Added by lru_cache
# ────────────────────────────────────────────────────────────────────────
# Structured objects stored in registries
# ────────────────────────────────────────────────────────────────────────
@dataclass(frozen=True)
class MetricSpec:
    """Bundle compute fn, aggregator, and *higher-is-better* flag."""

    compute: Callable[[Any, Any], Any]
    aggregate: Callable[[Iterable[Any]], Mapping[str, float]]
    higher_is_better: bool = True
    output_type: str | None = None  # e.g., "probability", "string", "numeric"
    requires: list[str] | None = None  # Dependencies on other metrics/data


# ────────────────────────────────────────────────────────────────────────
# Concrete registries used by lm_eval
# ────────────────────────────────────────────────────────────────────────
from lm_eval.api.model import LM  # noqa: E402


model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = (
    Registry("metric aggregation")
)
higher_is_better_registry: Registry[bool] = Registry("higher-is-better flag")
filter_registry: Registry[Callable] = Registry("filter")

# Default metric registry for output types
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
@@ -90,107 +347,194 @@ DEFAULT_METRIC_REGISTRY = {
    "generate_until": ["exact_match"],
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY: dict[str, Callable] = {}


# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model = model_registry.register
get_model = model_registry.get

register_task = task_registry.register
get_task = task_registry.get


# Special handling for metric registration, which uses a different API
def register_metric(**kwargs):
    """Register a metric with metadata.

    Compatible with the old registry API that used keyword arguments.
    """

    def decorate(fn):
        metric_name = kwargs.get("metric")
        if not metric_name:
            raise ValueError("metric name is required")

        # Create MetricSpec with the function and metadata
        spec = MetricSpec(
            compute=fn,
            aggregate=lambda x: {},  # Default aggregation returns empty dict
            higher_is_better=kwargs.get("higher_is_better", True),
            output_type=kwargs.get("output_type"),
            requires=kwargs.get("requires"),
        )
        # Register in metric registry
        metric_registry._objects[metric_name] = spec

        # Also handle aggregation if specified
        if "aggregation" in kwargs:
            agg_name = kwargs["aggregation"]
            # Try to get aggregation from AGGREGATION_REGISTRY
            if agg_name in AGGREGATION_REGISTRY:
                spec = MetricSpec(
                    compute=fn,
                    aggregate=AGGREGATION_REGISTRY[agg_name],
                    higher_is_better=kwargs.get("higher_is_better", True),
                    output_type=kwargs.get("output_type"),
                    requires=kwargs.get("requires"),
                )
                metric_registry._objects[metric_name] = spec

        # Handle higher_is_better registry
        if "higher_is_better" in kwargs:
            higher_is_better_registry._objects[metric_name] = kwargs["higher_is_better"]

        return fn

    return decorate


def get_metric(name: str, hf_evaluate_metric=False):
    """Get a metric by name, with fallback to HF Evaluate."""
    if not hf_evaluate_metric:
        try:
            spec = metric_registry.get(name)
            if isinstance(spec, MetricSpec):
                return spec.compute
            return spec
        except KeyError:
            import logging

            logging.getLogger(__name__).warning(
                f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
            )

    # Fallback to HF Evaluate
    try:
        import evaluate as hf_evaluate

        metric_object = hf_evaluate.load(name)
        return metric_object.compute
    except Exception:
        import logging

        logging.getLogger(__name__).error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )
        return None
register_metric_aggregation = metric_agg_registry.register


def get_metric_aggregation(metric_name: str):
    """Get the aggregation function for a metric."""
    # First try to get from metric registry (for metrics registered with aggregation)
    if metric_name in metric_registry._objects:
        metric_spec = metric_registry._objects[metric_name]
        if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
            return metric_spec.aggregate

    # Fall back to metric_agg_registry (for standalone aggregations)
    if metric_name in metric_agg_registry._objects:
        return metric_agg_registry._objects[metric_name]

    # If not found, raise error
    raise KeyError(
        f"Unknown metric aggregation '{metric_name}'. Available: {list(AGGREGATION_REGISTRY.keys())}"
    )


register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get

register_filter = filter_registry.register
get_filter = filter_registry.get


# Special handling for AGGREGATION_REGISTRY which works differently
def register_aggregation(name: str):
    def decorate(fn):
        if name in AGGREGATION_REGISTRY:
            raise ValueError(
                f"aggregation named '{name}' conflicts with existing registered aggregation!"
            )
        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


def get_aggregation(name: str) -> Callable[[], dict[str, Callable]]:
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        import logging

        logging.getLogger(__name__).warning(
            f"{name} not a registered aggregation metric!"
        )
        return None
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry-point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
#     "lm_eval.models": model_registry,
#     "lm_eval.tasks": task_registry,
#     "lm_eval.metrics": metric_registry,
# }.items():
#     for _ep in md.entry_points(group=_group):
#         _reg.register(_ep.name, lazy=_ep)


# ────────────────────────────────────────────────────────────────────────
# Convenience
# ────────────────────────────────────────────────────────────────────────
def freeze_all() -> None:  # pragma: no cover
    """Freeze every global registry (idempotent)."""
    for _reg in (
        model_registry,
        task_registry,
        metric_registry,
        metric_agg_registry,
        higher_is_better_registry,
        filter_registry,
    ):
        _reg.freeze()


# ────────────────────────────────────────────────────────────────────────
# Backwards-compatibility read-only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY: Mapping[str, type[LM]] = MappingProxyType(model_registry._objects)  # type: ignore[attr-defined]
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = MappingProxyType(
    task_registry._objects
)  # type: ignore[attr-defined]
METRIC_REGISTRY: Mapping[str, MetricSpec] = MappingProxyType(metric_registry._objects)  # type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = MappingProxyType(
    metric_agg_registry._objects
)  # type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = MappingProxyType(
    higher_is_better_registry._objects
)  # type: ignore[attr-defined]
FILTER_REGISTRY: Mapping[str, Callable] = MappingProxyType(filter_registry._objects)  # type: ignore[attr-defined]
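
A minimal usage sketch of the refactored registry, for orientation only (not part of the diff); the names ToyLM, "my-model", "toy-task", and "my_pkg.tasks:build_toy_task" are placeholders:

# Sketch only: exercising the new Registry API with placeholder names.
from lm_eval.api.model import LM
from lm_eval.api.registry import get_model, model_registry, task_registry


@model_registry.register("my-model")  # eager registration via the decorator path
class ToyLM(LM):
    ...


# Lazy registration: the target module is imported only on the first get()
task_registry.register("toy-task", lazy="my_pkg.tasks:build_toy_task")

assert get_model("my-model") is ToyLM       # get_model is model_registry.get
assert model_registry["my-model"] is ToyLM  # __getitem__ delegates to get()
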
@@ -3,18 +3,15 @@ import ast
import logging
import random
import re
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
...
@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, f" {choice}"),
                idx=i,
                **kwargs,
            )
...
# Models are now lazily loaded via the registry system
# No need to import them all at once

# Define model mappings for lazy registration
MODEL_MAPPING = {
    "anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
    "anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
    "anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
    "local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
    "local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
    "openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
    "openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
    "dummy": "lm_eval.models.dummy:DummyLM",
    "gguf": "lm_eval.models.gguf:GGUFLM",
    "ggml": "lm_eval.models.gguf:GGUFLM",
    "hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
    "steered": "lm_eval.models.hf_steered:SteeredHF",
    "hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
    "hf-auto": "lm_eval.models.huggingface:HFLM",
    "hf": "lm_eval.models.huggingface:HFLM",
    "huggingface": "lm_eval.models.huggingface:HFLM",
    "watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
    "mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
    "nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
    "neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
    "ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
    "openvino": "lm_eval.models.optimum_lm:OptimumLM",
    "sglang": "lm_eval.models.sglang_causallms:SGLANG",
    "sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
    "textsynth": "lm_eval.models.textsynth:TextSynthLM",
    "vllm": "lm_eval.models.vllm_causallms:VLLM",
    "vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}


# Register all models lazily
def _register_all_models():
    """Register all known models lazily in the registry."""
    from lm_eval.api.registry import model_registry

    for name, path in MODEL_MAPPING.items():
        # Only register if not already present (avoids conflicts when modules are imported)
        if name not in model_registry:
            # Call register with the lazy parameter, returns a decorator
            model_registry.register(name, lazy=path)(None)


# Call registration on module import
_register_all_models()


__all__ = ["MODEL_MAPPING"]


try:
...
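
A brief sketch of how the lazy mapping above resolves at lookup time (illustrative only, assuming the package is importable; no model backend is loaded until the first lookup):

# Sketch only: resolving a lazily registered model alias.
import lm_eval.models  # runs _register_all_models(), storing placeholder strings
from lm_eval.api.registry import get_model

# The registry still holds "lm_eval.models.huggingface:HFLM" as a string here;
# the first lookup imports the module and swaps in the concrete class.
hf_cls = get_model("hf")
assert hf_cls is get_model("huggingface")  # both aliases resolve to the same HFLM class
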
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Union

import torch
from peft.peft_model import PeftModel
...
@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast

from tqdm import tqdm
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
        raise ValueError(error_msg)


@cache
def get_watsonx_credentials() -> Dict[str, str]:
    """
    Retrieves Watsonx API credentials from environment variables.
...
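
For context on the decorator swap above: functools.cache (Python 3.9+) is shorthand for lru_cache(maxsize=None), so the credential lookup stays memoised. A small stand-alone sketch with a made-up function:

# Sketch only: @cache and @lru_cache(maxsize=None) memoise identically.
from functools import cache


@cache  # equivalent to @lru_cache(maxsize=None)
def load_credentials() -> dict:
    print("reading environment...")  # runs only on the first call
    return {"url": "...", "apikey": "..."}


load_credentials()  # computes and caches the result
load_credentials()  # returns the cached dict without re-running the body
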
@@ -40,7 +40,7 @@ try:
    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
    pass

if TYPE_CHECKING:
    pass
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
        self.indexes = None


class ACPGrammarParser:
    def __init__(self, task) -> None:
        self.task = task
        with open(GRAMMAR_FILE) as f:
...
@@ -556,8 +556,8 @@ class STRIPS:
        return set([fix_name(str(x)) for x in ret])

    def PDDL_replace_init_pddl_parser(self, s):
        d = DomainParser()(open(self.domain_file).read().lower())
        p = ProblemParser()(open(self.problem_file).read().lower())
        new_state = get_atoms_pddl(d, p, s | self.get_static())
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
        self.indexes = None


class ACPGrammarParser:
    def __init__(self, task) -> None:
        self.task = task
        with open(GRAMMAR_FILE) as f:
...
@@ -556,8 +556,8 @@ class STRIPS:
        return set([fix_name(str(x)) for x in ret])

    def PDDL_replace_init_pddl_parser(self, s):
        d = DomainParser()(open(self.domain_file).read().lower())
        p = ProblemParser()(open(self.problem_file).read().lower())
        new_state = get_atoms_pddl(d, p, s | self.get_static())
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from tqdm import tqdm

# from lm_eval.api.registryv2 import ALL_TASKS

eval_logger = logging.getLogger(__name__)
...