Commit 93b2ab37 authored by Baber's avatar Baber
Browse files

refactor registry

parent de496b80
......@@ -4,8 +4,8 @@ import os
import random
import re
import string
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, TypeVar
from collections.abc import Iterable, Sequence
from typing import Callable, List, Optional, TypeVar
import numpy as np
import sacrebleu
......
This diff is collapsed.
......@@ -3,18 +3,15 @@ import ast
import logging
import random
import re
from collections.abc import Callable
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
......@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
arguments=(ctx, f" {choice}"),
idx=i,
**kwargs,
)
......
from . import (
anthropic_llms,
api_models,
dummy,
gguf,
hf_audiolm,
hf_steered,
hf_vlms,
huggingface,
ibm_watsonx_ai,
mamba_lm,
nemo_lm,
neuron_optimum,
openai_completions,
optimum_ipex,
optimum_lm,
sglang_causallms,
sglang_generate_API,
textsynth,
vllm_causallms,
vllm_vlms,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
# Maps each user-facing model name to its implementation as an import target
# in "module.path:ClassName" form; the class is imported only when the model
# is actually requested.  Several names are aliases for the same class
# (e.g. "gguf"/"ggml", "hf"/"hf-auto"/"huggingface").
MODEL_MAPPING: dict[str, str] = {
    "anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
    "anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
    "anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
    "local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
    "local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
    "openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
    "openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
    "dummy": "lm_eval.models.dummy:DummyLM",
    "gguf": "lm_eval.models.gguf:GGUFLM",
    "ggml": "lm_eval.models.gguf:GGUFLM",
    "hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
    "steered": "lm_eval.models.hf_steered:SteeredHF",
    "hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
    "hf-auto": "lm_eval.models.huggingface:HFLM",
    "hf": "lm_eval.models.huggingface:HFLM",
    "huggingface": "lm_eval.models.huggingface:HFLM",
    "watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
    "mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
    "nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
    "neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
    "ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
    "openvino": "lm_eval.models.optimum_lm:OptimumLM",
    "sglang": "lm_eval.models.sglang_causallms:SGLANG",
    "sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
    "textsynth": "lm_eval.models.textsynth:TextSynthLM",
    "vllm": "lm_eval.models.vllm_causallms:VLLM",
    "vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}
# Register all models lazily
def _register_all_models():
    """Lazily register every entry of ``MODEL_MAPPING`` with the registry.

    Each mapping value is a ``"module.path:ClassName"`` import target that
    the registry resolves only when the model is first requested.  Names
    already present in the registry (e.g. because their module was imported
    eagerly) are left untouched.
    """
    from lm_eval.api.registry import model_registry

    for model_name, target in MODEL_MAPPING.items():
        if model_name in model_registry:
            # Already registered by a direct module import; keep it as-is.
            continue
        # register(..., lazy=...) returns a decorator; invoking it with
        # None records the lazy import target without a concrete class.
        model_registry.register(model_name, lazy=target)(None)


# Populate the registry as soon as this package is imported.
_register_all_models()

__all__ = ["MODEL_MAPPING"]
try:
......
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from peft.peft_model import PeftModel
......
......@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
......@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environmental variables.
......
......@@ -40,7 +40,7 @@ try:
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")
if TYPE_CHECKING:
pass
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from tqdm import tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registryv2 import ALL_TASKS
eval_logger = logging.getLogger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment