Commit 424a4280 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed register_model origin and other imports

parent ddb7c0f3
...@@ -118,7 +118,7 @@ def matthews_corrcoef(items): ...@@ -118,7 +118,7 @@ def matthews_corrcoef(items):
@register_metric( @register_metric(
metric="f1_score", metric="f1",
higher_is_better=True, higher_is_better=True,
output_type="multiple_choice", output_type="multiple_choice",
aggregation="mean", aggregation="mean",
......
...@@ -49,9 +49,9 @@ def register_task(name): ...@@ -49,9 +49,9 @@ def register_task(name):
def register_group(name): def register_group(name):
def decorate(fn): def decorate(fn):
assert ( # assert (
name not in GROUP_REGISTRY # name not in GROUP_REGISTRY
), f"group named '{name}' conflicts with existing registered group!" # ), f"group named '{name}' conflicts with existing registered group!"
func_name = func2task_index[fn.__name__] func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY: if name in GROUP_REGISTRY:
......
import random import random
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@register_model("dummy") @register_model("dummy")
......
...@@ -7,7 +7,8 @@ import torch.nn.functional as F ...@@ -7,7 +7,8 @@ import torch.nn.functional as F
from lm_eval import utils from lm_eval import utils
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from accelerate import Accelerator from accelerate import Accelerator
from itertools import islice from itertools import islice
......
...@@ -6,7 +6,8 @@ import numpy as np ...@@ -6,7 +6,8 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
def get_result(response, ctxlen): def get_result(response, ctxlen):
......
...@@ -16,7 +16,8 @@ import os ...@@ -16,7 +16,8 @@ import os
import requests as _requests import requests as _requests
import time import time
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment