Commit 4facd5c8 authored by Baber's avatar Baber
Browse files

type hints

parent 2ae642d8
......@@ -1049,7 +1049,9 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
def doc_to_target(
self, doc: dict, doc_to_target=None
) -> Union[int, str, list[int]]:
# if self.prompt is not None:
# doc_to_target = self.prompt
if doc_to_target is not None:
......@@ -1096,7 +1098,9 @@ class ConfigurableTask(Task):
raise TypeError
def doc_to_choice(
self, doc: dict, doc_to_choice: Union[str, list, dict, None] = None
self,
doc: dict,
doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
) -> List[str]:
# if self.prompt is not None:
# doc_to_choice = self.prompt
......@@ -1119,8 +1123,8 @@ class ConfigurableTask(Task):
return list(doc_to_choice.values())
elif callable(doc_to_choice):
return doc_to_choice(doc)
elif hasattr(doc_to_choice, "get_answer_choices_list"):
return doc_to_choice.get_answer_choices_list(doc)
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc)
else:
raise TypeError
......@@ -1329,6 +1333,8 @@ class ConfigurableTask(Task):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[: len(choices)]
else:
lls_unconditional = None
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
......@@ -1386,6 +1392,9 @@ class ConfigurableTask(Task):
}
if "acc_mutual_info" in use_metric:
assert lls_unconditional is not None, (
"lls_unconditional should not be None if acc_mutual_info is in use_metric"
)
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
......
......@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
def serialize_callable(
value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable[..., Any], str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
......
from __future__ import annotations
import itertools
import json
import logging
......@@ -5,7 +7,7 @@ import os
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
import numpy as np
import torch
......@@ -48,7 +50,7 @@ eval_logger = logging.getLogger(__name__)
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
model_args: Optional[Union[str, dict[str, Any]]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = None,
......@@ -414,7 +416,7 @@ def simple_evaluate(
def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
limit: int | float | None = None,
samples: Optional[dict] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
......
......@@ -105,7 +105,13 @@ plugins.md034.enabled = false # no-bare-urls
[tool.ruff.lint]
extend-select = ["I"]
[tool.ruff]
target-version = "py39"
extend-select = ["I", "UP", "E", "C419"]
ignore = ["E402", "E731"]
[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
known-first-party = ["lm_eval"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment