Commit 93b2ab37 authored by Baber

refactor registry

parent de496b80
......@@ -4,8 +4,8 @@ import os
import random
import re
import string
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, TypeVar
from collections.abc import Iterable, Sequence
from typing import Callable, List, Optional, TypeVar
import numpy as np
import sacrebleu
......
import logging
from typing import Callable, Dict, Union
from __future__ import annotations
import importlib
import inspect
import threading
from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import (
Any,
Callable,
Generic,
TypeVar,
)
try: # Python≥3.10
import importlib.metadata as md
except ImportError: # pragma: no cover - fallback for 3.8/3.9 runtimes
import importlib_metadata as md # type: ignore
__all__ = [
"Registry",
"MetricSpec",
# concrete registries
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
# helper
"freeze_all",
# Legacy compatibility
"DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY",
"register_model",
"get_model",
"register_task",
"get_task",
"register_metric",
"get_metric",
"register_metric_aggregation",
"get_metric_aggregation",
"register_higher_is_better",
"is_higher_better",
"register_filter",
"get_filter",
"register_aggregation",
"get_aggregation",
"MODEL_REGISTRY",
"TASK_REGISTRY",
"METRIC_REGISTRY",
"METRIC_AGGREGATION_REGISTRY",
"HIGHER_IS_BETTER_REGISTRY",
"FILTER_REGISTRY",
]
T = TypeVar("T")
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
class Registry(Generic[T]):
"""Name -> object mapping with decorator helpers and **lazy import** support."""
#: The underlying mutable mapping (might turn into MappingProxy on freeze)
_objects: MutableMapping[str, T | str | md.EntryPoint]
def __init__(
self,
name: str,
*,
base_cls: type[T] | None = None,
store: MutableMapping[str, T | str | md.EntryPoint] | None = None,
validator: Callable[[T], bool] | None = None,
) -> None:
self._name: str = name
self._base_cls: type[T] | None = base_cls
self._objects = store if store is not None else {}
self._metadata: dict[
str, dict[str, Any]
] = {} # Store metadata for each registered item
self._validator = validator # Custom validation function
self._lock = threading.RLock()
# ------------------------------------------------------------------
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
def register(
self,
*aliases: str,
lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def _do_register(target: T | str | md.EntryPoint) -> None:
if not aliases:
_aliases = (getattr(target, "__name__", str(target)),)
else:
_aliases = aliases
with self._lock:
for alias in _aliases:
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
# Store metadata if provided
if metadata:
self._metadata[alias] = metadata
# ─── decorator path ───
def decorator(obj: T) -> T: # type: ignore[valid-type]
_do_register(obj)
return obj
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
_do_register(lazy)
return lambda x: x # no-op decorator in case the return value is still applied as one
return decorator
def register_bulk(
self,
items: dict[str, T | str | md.EntryPoint],
metadata: dict[str, dict[str, Any]] | None = None,
) -> None:
"""Register multiple items at once.
Args:
items: Dictionary mapping aliases to objects/lazy paths
metadata: Optional dictionary mapping aliases to metadata
"""
with self._lock:
for alias, target in items.items():
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
# Store metadata if provided
if metadata and alias in metadata:
self._metadata[alias] = metadata[alias]
# ------------------------------------------------------------------
# Lookup & materialisation
# ------------------------------------------------------------------
@lru_cache(maxsize=256) # Bounded cache to prevent memory growth
def _materialise(self, target: T | str | md.EntryPoint) -> T:
"""Import *target* if it is a dotted‑path string or EntryPoint."""
if isinstance(target, str):
mod, _, obj_name = target.partition(":")
if not _:
raise ValueError(
f"Lazy path '{target}' must be in 'module:object' form"
)
module = importlib.import_module(mod)
return getattr(module, obj_name)
if isinstance(target, md.EntryPoint):
return target.load()
return target # concrete already
def get(self, alias: str) -> T:
with self._lock:
try:
target = self._objects[alias]
except KeyError as exc:
raise KeyError(
f"Unknown {self._name} '{alias}'. Available: "
f"{', '.join(self._objects)}"
) from exc
# Only materialize if it's a string or EntryPoint (lazy placeholder)
if isinstance(target, (str, md.EntryPoint)):
concrete: T = self._materialise(target)
# First‑touch: swap placeholder with concrete obj for future calls
if concrete is not target:
self._objects[alias] = concrete
else:
# Already materialized, just return it
concrete = target
# Late type check (for placeholders)
if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{concrete} does not inherit from {self._base_cls} "
f"(registered under alias '{alias}')"
)
import evaluate as hf_evaluate
# Custom validation
if self._validator is not None and not self._validator(concrete):
raise ValueError(
f"{concrete} failed custom validation for {self._name} registry "
f"(registered under alias '{alias}')"
)
from lm_eval.api.model import LM
return concrete
# Mapping / dunder helpers -------------------------------------------------
eval_logger = logging.getLogger(__name__)
def __getitem__(self, alias: str) -> T: # noqa
return self.get(alias)
MODEL_REGISTRY = {}
def __iter__(self): # noqa
return iter(self._objects)
def __len__(self) -> int: # noqa
return len(self._objects)
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def items(self): # noqa
return self._objects.items()
def decorate(cls):
for name in names:
assert issubclass(cls, LM), (
f"Model '{name}' ({cls.__name__}) must extend LM class"
)
# Introspection -----------------------------------------------------------
assert name not in MODEL_REGISTRY, (
f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
)
def origin(self, alias: str) -> str | None:
obj = self._objects.get(alias)
try:
if isinstance(obj, str) or isinstance(obj, md.EntryPoint):
return None # placeholder - unknown until imported
file = inspect.getfile(obj) # type: ignore[arg-type]
line = inspect.getsourcelines(obj)[1] # type: ignore[arg-type]
return f"{file}:{line}"
except (
TypeError,
OSError,
AttributeError,
): # pragma: no cover - best-effort only
# TypeError: object not suitable for inspect
# OSError: file not found or accessible
# AttributeError: object lacks expected attributes
return None
MODEL_REGISTRY[name] = cls
return cls
def get_metadata(self, alias: str) -> dict[str, Any] | None:
"""Get metadata for a registered item."""
with self._lock:
return self._metadata.get(alias)
return decorate
# Mutability --------------------------------------------------------------
def freeze(self):
"""Make the registry *names* immutable (materialisation still works)."""
with self._lock:
if isinstance(self._objects, MappingProxyType):
return # already frozen
self._objects = MappingProxyType(dict(self._objects)) # type: ignore[assignment]
def get_model(model_name):
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
)
def clear(self):
"""Clear the registry (useful for tests). Cannot be called on frozen registries."""
with self._lock:
if isinstance(self._objects, MappingProxyType):
raise RuntimeError("Cannot clear a frozen registry")
self._objects.clear()
self._metadata.clear()
self._materialise.cache_clear() # type: ignore[attr-defined] # Added by lru_cache
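To make the decorator and lazy registration paths of the new Registry class concrete, here is a minimal usage sketch; demo_registry, Foo, and the my_pkg.bars module are hypothetical names used only for illustration.

demo_registry: Registry[object] = Registry("demo")

@demo_registry.register("foo")                    # decorator path: stores the concrete class
class Foo:
    pass

demo_registry.register("bar", lazy="my_pkg.bars:Bar")     # lazy path: "module:object" placeholder
demo_registry.register_bulk({"baz": "my_pkg.bars:Baz"})   # bulk form takes a dict of aliases

assert demo_registry.get("foo") is Foo   # concrete objects are returned as-is
assert "bar" in demo_registry            # placeholders are listed without importing my_pkg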
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}
# ────────────────────────────────────────────────────────────────────────
# Structured objects stored in registries
# ────────────────────────────────────────────────────────────────────────
def register_task(name):
def decorate(fn):
assert name not in TASK_REGISTRY, (
f"task named '{name}' conflicts with existing registered task!"
)
@dataclass(frozen=True)
class MetricSpec:
"""Bundle compute fn, aggregator, and *higher‑is‑better* flag."""
TASK_REGISTRY[name] = fn
ALL_TASKS.add(name)
func2task_index[fn.__name__] = name
return fn
compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], Mapping[str, float]]
higher_is_better: bool = True
output_type: str | None = None # e.g., "probability", "string", "numeric"
requires: list[str] | None = None # Dependencies on other metrics/data
return decorate
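A short sketch of how a MetricSpec bundle might be built; exact_match_fn and mean_agg are placeholder callables, not functions introduced by this commit.

def exact_match_fn(pred, ref):
    # per-example score
    return 1.0 if pred == ref else 0.0

def mean_agg(scores):
    scores = list(scores)
    return {"exact_match": sum(scores) / len(scores) if scores else 0.0}

em_spec = MetricSpec(compute=exact_match_fn, aggregate=mean_agg, higher_is_better=True)
em_spec.aggregate([1.0, 0.0, 1.0])   # -> {"exact_match": 0.666...}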
# ────────────────────────────────────────────────────────────────────────
# Concrete registries used by lm_eval
# ────────────────────────────────────────────────────────────────────────
def register_group(name):
def decorate(fn):
func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY:
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
ALL_TASKS.add(name)
return fn
return decorate
from lm_eval.api.model import LM # noqa: E402
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = (
Registry("metric aggregation")
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[Callable] = Registry("filter")
# Default metric registry for output types
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
......@@ -90,107 +347,194 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until": ["exact_match"],
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY: dict[str, Callable] = {}
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model = model_registry.register
get_model = model_registry.get
register_task = task_registry.register
get_task = task_registry.get
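A sketch of the legacy decorator API, which now simply forwards to the Registry instances above; MyDemoLM and the "my-demo-lm" alias are hypothetical.

@register_model("my-demo-lm")        # same call sites as before, now backed by model_registry
class MyDemoLM(LM):
    pass

assert get_model("my-demo-lm") is MyDemoLM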
# Special handling for metric registration, which uses a different API
def register_metric(**kwargs):
"""Register a metric with metadata.
Compatible with old registry API that used keyword arguments.
"""
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert value not in registry, (
f"{key} named '{value}' conflicts with existing registered {key}!"
metric_name = kwargs.get("metric")
if not metric_name:
raise ValueError("metric name is required")
# Create MetricSpec with the function and metadata
spec = MetricSpec(
compute=fn,
aggregate=lambda x: {}, # Default aggregation returns empty dict
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
# Register in metric registry
metric_registry._objects[metric_name] = spec
# Also handle aggregation if specified
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
# Try to get aggregation from AGGREGATION_REGISTRY
if agg_name in AGGREGATION_REGISTRY:
spec = MetricSpec(
compute=fn,
aggregate=AGGREGATION_REGISTRY[agg_name],
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
metric_registry._objects[metric_name] = spec
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
# Handle higher_is_better registry
if "higher_is_better" in kwargs:
higher_is_better_registry._objects[metric_name] = kwargs["higher_is_better"]
return fn
return decorate
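A sketch of the keyword-based metric registration kept for compatibility; demo_accuracy is illustrative, and the "mean" aggregation is only attached if something has registered it in AGGREGATION_REGISTRY.

@register_metric(
    metric="demo_accuracy",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",   # resolved through AGGREGATION_REGISTRY when present
)
def demo_accuracy(items):
    # items are assumed to be (gold, prediction) pairs for this sketch
    return sum(gold == pred for gold, pred in items) / len(items)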
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
def get_metric(name: str, hf_evaluate_metric=False):
"""Get a metric by name, with fallback to HF evaluate."""
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
else:
eval_logger.warning(
try:
spec = metric_registry.get(name)
if isinstance(spec, MetricSpec):
return spec.compute
return spec
except KeyError:
import logging
logging.getLogger(__name__).warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
)
# Fallback to HF evaluate
try:
import evaluate as hf_evaluate
metric_object = hf_evaluate.load(name)
return metric_object.compute
except Exception:
eval_logger.error(
import logging
logging.getLogger(__name__).error(
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
)
return None
def register_aggregation(name: str):
def decorate(fn):
assert name not in AGGREGATION_REGISTRY, (
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
register_metric_aggregation = metric_agg_registry.register
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_metric_aggregation(metric_name: str):
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
if metric_name in metric_registry._objects:
metric_spec = metric_registry._objects[metric_name]
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate
# Fall back to metric_agg_registry (for standalone aggregations)
if metric_name in metric_agg_registry._objects:
return metric_agg_registry._objects[metric_name]
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
# If not found, raise error
raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(AGGREGATION_REGISTRY.keys())}"
)
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get
register_filter = filter_registry.register
get_filter = filter_registry.get
def is_higher_better(metric_name) -> bool:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!"
)
def register_filter(name):
def decorate(cls):
if name in FILTER_REGISTRY:
eval_logger.info(
f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
# Special handling for AGGREGATION_REGISTRY, which stays a plain dict rather than a Registry
def register_aggregation(name: str):
def decorate(fn):
if name in AGGREGATION_REGISTRY:
raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
FILTER_REGISTRY[name] = cls
return cls
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
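A matching sketch for register_aggregation; demo_mean is a hypothetical aggregator.

@register_aggregation("demo_mean")
def demo_mean(arr):
    arr = list(arr)
    return sum(arr) / len(arr) if arr else 0.0

assert AGGREGATION_REGISTRY["demo_mean"] is demo_mean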
def get_filter(filter_name: Union[str, Callable]) -> Callable:
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]]:
try:
return FILTER_REGISTRY[filter_name]
except KeyError as e:
if callable(filter_name):
return filter_name
else:
eval_logger.warning(f"filter `{filter_name}` is not registered!")
raise e
return AGGREGATION_REGISTRY[name]
except KeyError:
import logging
logging.getLogger(__name__).warning(
f"{name} not a registered aggregation metric!"
)
return None
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry‑point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
# "lm_eval.models": model_registry,
# "lm_eval.tasks": task_registry,
# "lm_eval.metrics": metric_registry,
# }.items():
# for _ep in md.entry_points(group=_group):
# _reg.register(_ep.name, lazy=_ep)
# ────────────────────────────────────────────────────────────────────────
# Convenience
# ────────────────────────────────────────────────────────────────────────
def freeze_all() -> None: # pragma: no cover
"""Freeze every global registry (idempotent)."""
for _reg in (
model_registry,
task_registry,
metric_registry,
metric_agg_registry,
higher_is_better_registry,
filter_registry,
):
_reg.freeze()
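A sketch of what freezing means in practice; the "late-model" alias and my_pkg path are hypothetical.

freeze_all()
try:
    model_registry.register("late-model", lazy="my_pkg.models:LateLM")
except TypeError:
    # a frozen registry is backed by MappingProxyType, which rejects item assignment
    pass
# lookups via model_registry.get(...) and lazy materialisation keep working after freezing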
# ────────────────────────────────────────────────────────────────────────
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY: Mapping[str, type[LM]] = MappingProxyType(model_registry._objects) # type: ignore[attr-defined]
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = MappingProxyType(
task_registry._objects
) # type: ignore[attr-defined]
METRIC_REGISTRY: Mapping[str, MetricSpec] = MappingProxyType(metric_registry._objects) # type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = MappingProxyType(
metric_agg_registry._objects
) # type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = MappingProxyType(
higher_is_better_registry._objects
) # type: ignore[attr-defined]
FILTER_REGISTRY: Mapping[str, Callable] = MappingProxyType(filter_registry._objects) # type: ignore[attr-defined]
......@@ -3,18 +3,15 @@ import ast
import logging
import random
import re
from collections.abc import Callable
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
......@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
arguments=(ctx, f" {choice}"),
idx=i,
**kwargs,
)
......
from . import (
anthropic_llms,
api_models,
dummy,
gguf,
hf_audiolm,
hf_steered,
hf_vlms,
huggingface,
ibm_watsonx_ai,
mamba_lm,
nemo_lm,
neuron_optimum,
openai_completions,
optimum_ipex,
optimum_lm,
sglang_causallms,
sglang_generate_API,
textsynth,
vllm_causallms,
vllm_vlms,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
MODEL_MAPPING = {
"anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
"anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
"anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
"local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
"local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
"openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
"openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
"dummy": "lm_eval.models.dummy:DummyLM",
"gguf": "lm_eval.models.gguf:GGUFLM",
"ggml": "lm_eval.models.gguf:GGUFLM",
"hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
"steered": "lm_eval.models.hf_steered:SteeredHF",
"hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
"hf-auto": "lm_eval.models.huggingface:HFLM",
"hf": "lm_eval.models.huggingface:HFLM",
"huggingface": "lm_eval.models.huggingface:HFLM",
"watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
"mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
"nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
"neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
"ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
"openvino": "lm_eval.models.optimum_lm:OptimumLM",
"sglang": "lm_eval.models.sglang_causallms:SGLANG",
"sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
"textsynth": "lm_eval.models.textsynth:TextSynthLM",
"vllm": "lm_eval.models.vllm_causallms:VLLM",
"vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}
# Register all models lazily
def _register_all_models():
"""Register all known models lazily in the registry."""
from lm_eval.api.registry import model_registry
for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry:
# Call register with the lazy parameter, returns a decorator
model_registry.register(name, lazy=path)(None)
# Call registration on module import
_register_all_models()
__all__ = ["MODEL_MAPPING"]
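A sketch of how a lazily registered model resolves at lookup time; it assumes lm_eval is importable and uses the real "dummy" alias from MODEL_MAPPING above.

from lm_eval.api.registry import get_model

dummy_cls = get_model("dummy")            # first call imports lm_eval.models.dummy and swaps in DummyLM
assert get_model("dummy") is dummy_cls    # later calls return the cached concrete class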
try:
......
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from peft.peft_model import PeftModel
......
......@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
......@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environmental variables.
......
......@@ -40,7 +40,7 @@ try:
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")
if TYPE_CHECKING:
pass
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from tqdm import tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registryv2 import ALL_TASKS
eval_logger = logging.getLogger(__name__)
......