Commit 4facd5c8 authored by Baber's avatar Baber
Browse files

type hints

parent 2ae642d8
...@@ -1049,7 +1049,9 @@ class ConfigurableTask(Task): ...@@ -1049,7 +1049,9 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]: def doc_to_target(
self, doc: dict, doc_to_target=None
) -> Union[int, str, list[int]]:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_target = self.prompt # doc_to_target = self.prompt
if doc_to_target is not None: if doc_to_target is not None:
...@@ -1096,7 +1098,9 @@ class ConfigurableTask(Task): ...@@ -1096,7 +1098,9 @@ class ConfigurableTask(Task):
raise TypeError raise TypeError
def doc_to_choice( def doc_to_choice(
self, doc: dict, doc_to_choice: Union[str, list, dict, None] = None self,
doc: dict,
doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
) -> List[str]: ) -> List[str]:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_choice = self.prompt # doc_to_choice = self.prompt
...@@ -1119,8 +1123,8 @@ class ConfigurableTask(Task): ...@@ -1119,8 +1123,8 @@ class ConfigurableTask(Task):
return list(doc_to_choice.values()) return list(doc_to_choice.values())
elif callable(doc_to_choice): elif callable(doc_to_choice):
return doc_to_choice(doc) return doc_to_choice(doc)
elif hasattr(doc_to_choice, "get_answer_choices_list"): # elif hasattr(doc_to_choice, "get_answer_choices_list"):
return doc_to_choice.get_answer_choices_list(doc) # return doc_to_choice.get_answer_choices_list(doc)
else: else:
raise TypeError raise TypeError
...@@ -1329,6 +1333,8 @@ class ConfigurableTask(Task): ...@@ -1329,6 +1333,8 @@ class ConfigurableTask(Task):
raise ValueError raise ValueError
# and this stores our "regular" conditional loglikelihoods # and this stores our "regular" conditional loglikelihoods
lls = lls[: len(choices)] lls = lls[: len(choices)]
else:
lls_unconditional = None
pred = np.argmax(lls) pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len) pred_norm = np.argmax(lls / completion_len)
...@@ -1386,6 +1392,9 @@ class ConfigurableTask(Task): ...@@ -1386,6 +1392,9 @@ class ConfigurableTask(Task):
} }
if "acc_mutual_info" in use_metric: if "acc_mutual_info" in use_metric:
assert lls_unconditional is not None, (
"lls_unconditional should not be None if acc_mutual_info is in use_metric"
)
lls_mutual_info = [ lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional) ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
] ]
......
...@@ -3,8 +3,8 @@ from typing import Any, Callable, Union ...@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
def serialize_callable( def serialize_callable(
value: Union[Callable, str], keep_callable=False value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable, str]: ) -> Union[Callable[..., Any], str]:
"""Serializes a given function or string. """Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned. If 'keep_callable' is True, the original callable is returned.
......
from __future__ import annotations
import itertools import itertools
import json import json
import logging import logging
...@@ -5,7 +7,7 @@ import os ...@@ -5,7 +7,7 @@ import os
import random import random
import time import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, Any, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -48,7 +50,7 @@ eval_logger = logging.getLogger(__name__) ...@@ -48,7 +50,7 @@ eval_logger = logging.getLogger(__name__)
@positional_deprecated @positional_deprecated
def simple_evaluate( def simple_evaluate(
model, model,
model_args: Optional[Union[str, dict]] = None, model_args: Optional[Union[str, dict[str, Any]]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None, tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None, num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = None, batch_size: Optional[Union[int, str]] = None,
...@@ -414,7 +416,7 @@ def simple_evaluate( ...@@ -414,7 +416,7 @@ def simple_evaluate(
def evaluate( def evaluate(
lm: "LM", lm: "LM",
task_dict, task_dict,
limit: Optional[int] = None, limit: int | float | None = None,
samples: Optional[dict] = None, samples: Optional[dict] = None,
cache_requests: bool = False, cache_requests: bool = False,
rewrite_requests_cache: bool = False, rewrite_requests_cache: bool = False,
......
...@@ -105,7 +105,13 @@ plugins.md034.enabled = false # no-bare-urls ...@@ -105,7 +105,13 @@ plugins.md034.enabled = false # no-bare-urls
[tool.ruff.lint] [tool.ruff.lint]
extend-select = ["I"] extend-select = ["I"]
[tool.ruff]
target-version = "py39"
extend-select = ["I", "UP", "E", "C419"]
ignore = ["E402", "E731"]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2 lines-after-imports = 2
known-first-party = ["lm_eval"] known-first-party = ["lm_eval"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment