Commit 93b2ab37 authored by Baber

refactor registry

parent de496b80
@@ -4,8 +4,8 @@ import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Iterable, Sequence
+from typing import Callable, List, Optional, TypeVar

 import numpy as np
 import sacrebleu
...
This diff is collapsed.
@@ -3,18 +3,15 @@ import ast
 import logging
 import random
 import re
-from collections.abc import Callable
+from collections.abc import Callable, Iterable, Iterator, Mapping
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import (
     Any,
     Dict,
-    Iterable,
-    Iterator,
     List,
     Literal,
-    Mapping,
     Optional,
     Tuple,
     Union,
@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice)),
+                arguments=(ctx, f" {choice}"),
                 idx=i,
                 **kwargs,
             )
...
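The hunks above are mechanical Python modernizations: since Python 3.9 (PEP 585), the container ABCs in typing (Iterable, Iterator, Mapping, Sequence, ...) are deprecated aliases for their collections.abc counterparts, and " {}".format(choice) reads more directly as an f-string. A minimal before/after sketch (the function and names below are illustrative, not from this repo):

# Before: deprecated typing aliases and str.format
#   from typing import Iterable, Mapping
#   label = " {}".format(choice)

# After: ABCs imported from collections.abc and an f-string
from collections.abc import Iterable, Mapping


def format_choices(choices: Iterable[str], scores: Mapping[str, float]) -> list[str]:
    # f-strings behave like str.format here, but keep each value
    # next to its placeholder and avoid a method call
    return [f" {choice} ({scores.get(choice, 0.0):.2f})" for choice in choices]


print(format_choices(["yes", "no"], {"yes": 0.91}))
# [' yes (0.91)', ' no (0.00)']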
-from . import (
-    anthropic_llms,
-    api_models,
-    dummy,
-    gguf,
-    hf_audiolm,
-    hf_steered,
-    hf_vlms,
-    huggingface,
-    ibm_watsonx_ai,
-    mamba_lm,
-    nemo_lm,
-    neuron_optimum,
-    openai_completions,
-    optimum_ipex,
-    optimum_lm,
-    sglang_causallms,
-    sglang_generate_API,
-    textsynth,
-    vllm_causallms,
-    vllm_vlms,
-)
-
-
-# TODO: implement __all__
+# Models are now lazily loaded via the registry system
+# No need to import them all at once
+
+# Define model mappings for lazy registration
+MODEL_MAPPING = {
+    "anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
+    "anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
+    "anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
+    "local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
+    "local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
+    "openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
+    "openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
+    "dummy": "lm_eval.models.dummy:DummyLM",
+    "gguf": "lm_eval.models.gguf:GGUFLM",
+    "ggml": "lm_eval.models.gguf:GGUFLM",
+    "hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
+    "steered": "lm_eval.models.hf_steered:SteeredHF",
+    "hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
+    "hf-auto": "lm_eval.models.huggingface:HFLM",
+    "hf": "lm_eval.models.huggingface:HFLM",
+    "huggingface": "lm_eval.models.huggingface:HFLM",
+    "watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
+    "mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
+    "nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
+    "neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
+    "ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
+    "openvino": "lm_eval.models.optimum_lm:OptimumLM",
+    "sglang": "lm_eval.models.sglang_causallms:SGLANG",
+    "sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
+    "textsynth": "lm_eval.models.textsynth:TextSynthLM",
+    "vllm": "lm_eval.models.vllm_causallms:VLLM",
+    "vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
+}
+
+
+# Register all models lazily
+def _register_all_models():
+    """Register all known models lazily in the registry."""
+    from lm_eval.api.registry import model_registry
+
+    for name, path in MODEL_MAPPING.items():
+        # Only register if not already present (avoids conflicts when modules are imported)
+        if name not in model_registry:
+            # Call register with the lazy parameter, returns a decorator
+            model_registry.register(name, lazy=path)(None)
+
+
+# Call registration on module import
+_register_all_models()
+
+__all__ = ["MODEL_MAPPING"]
+
 try:
...
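For context, here is a minimal sketch of the kind of lazy registry this refactor relies on. Everything below (LazyRegistry, _entries, _lazy_paths, get) is a hypothetical stand-in, not the actual lm_eval.api.registry implementation; the only assumption taken from the diff is that register(name, lazy="module:Class") records an import path and the class is imported on first lookup.

import importlib
from typing import Any, Callable, Dict, Optional


class LazyRegistry:
    """Maps names to objects, importing "module:attr" targets on first use."""

    def __init__(self) -> None:
        self._entries: Dict[str, Any] = {}     # already-resolved objects
        self._lazy_paths: Dict[str, str] = {}  # name -> "module:attr"

    def __contains__(self, name: str) -> bool:
        return name in self._entries or name in self._lazy_paths

    def register(self, name: str, lazy: Optional[str] = None) -> Callable[[Any], Any]:
        # Returns a decorator, so register(name, lazy=path)(None) records
        # only the import path without importing anything
        def decorator(obj: Any) -> Any:
            if lazy is not None:
                self._lazy_paths[name] = lazy
            else:
                self._entries[name] = obj
            return obj

        return decorator

    def get(self, name: str) -> Any:
        # Resolve and cache lazy entries on first access
        if name not in self._entries and name in self._lazy_paths:
            module_name, attr = self._lazy_paths[name].split(":")
            self._entries[name] = getattr(importlib.import_module(module_name), attr)
        return self._entries[name]

With a registry shaped like this, model_registry.register("vllm", lazy="lm_eval.models.vllm_causallms:VLLM")(None) stores only the string, so importing lm_eval.models no longer drags in vllm or other heavy optional backends until a model is actually requested.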
+from collections.abc import Generator
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Generator, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 from peft.peft_model import PeftModel
...
@@ -3,7 +3,7 @@ import json
 import logging
 import os
 import warnings
-from functools import lru_cache
+from functools import cache
 from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast

 from tqdm import tqdm
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
         raise ValueError(error_msg)


-@lru_cache(maxsize=None)
+@cache
 def get_watsonx_credentials() -> Dict[str, str]:
     """
     Retrieves Watsonx API credentials from environmental variables.
...
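The @cache swap is behavior-preserving: functools.cache, added in Python 3.9, is defined as lru_cache(maxsize=None), i.e. an unbounded memoization cache with none of the LRU eviction bookkeeping. A quick self-contained illustration:

from functools import cache, lru_cache


@cache  # Python 3.9+; equivalent to @lru_cache(maxsize=None)
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)


@lru_cache(maxsize=None)  # older spelling, identical behavior
def fib_compat(n: int) -> int:
    return n if n < 2 else fib_compat(n - 1) + fib_compat(n - 2)


assert fib(30) == fib_compat(30) == 832040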
@@ -40,7 +40,7 @@ try:
     if parse_version(version("vllm")) >= parse_version("0.8.3"):
         from vllm.entrypoints.chat_utils import resolve_hf_chat_template
 except ModuleNotFoundError:
-    pass
+    print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")

 if TYPE_CHECKING:
     pass
...
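The keyboard-mash print introduced in this hunk looks like a stray debug statement; the conventional shape for a version-gated optional import like this is a silent (or logged) fallback. A generic sketch, assuming only importlib.metadata and packaging; the None fallback is an illustrative choice, not lm-eval's actual behavior:

# Generic pattern for an optional, version-gated import
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse as parse_version

resolve_hf_chat_template = None
try:
    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except (ModuleNotFoundError, PackageNotFoundError):
    pass  # vllm missing or too old: callers must handle the None fallback

Catching PackageNotFoundError as well covers the case where vllm is not installed at all, since version("vllm") raises before the import is ever attempted.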
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
         self.indexes = None


-class ACPGrammarParser(object):
+class ACPGrammarParser:
     def __init__(self, task) -> None:
         self.task = task
         with open(GRAMMAR_FILE) as f:
@@ -556,8 +556,8 @@ class STRIPS:
         return set([fix_name(str(x)) for x in ret])

     def PDDL_replace_init_pddl_parser(self, s):
-        d = DomainParser()(open(self.domain_file, "r").read().lower())
-        p = ProblemParser()(open(self.problem_file, "r").read().lower())
+        d = DomainParser()(open(self.domain_file).read().lower())
+        p = ProblemParser()(open(self.problem_file).read().lower())
         new_state = get_atoms_pddl(d, p, s | self.get_static())
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
 from tqdm import tqdm


-# from lm_eval.api.registry import ALL_TASKS
+# from lm_eval.api.registryv2 import ALL_TASKS

 eval_logger = logging.getLogger(__name__)
...