Commit 97f5c020 authored by baberabb

added typehints

parent b0f67f2c
@@ -13,7 +13,7 @@ from tqdm import tqdm
 import datasets
 import numpy as np
-from typing import Union
+from typing import Union, List, Any, Tuple, Literal
 from collections.abc import Callable
 from lm_eval import utils
@@ -477,7 +477,7 @@ class Task(abc.ABC):
             eval_logger.warning("No filter defined, passing through instances")
             return self._instances
-    def dump_config(self):
+    def dump_config(self) -> dict:
         """Returns a dictionary representing the task's config.
         :returns: str
@@ -489,14 +489,13 @@
 class ConfigurableTask(Task):
     VERSION = "Yaml"
     OUTPUT_TYPE = None
     CONFIG = None
     def __init__(
         self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
-    ):
+    ): # TODO no super() call here
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -662,25 +661,25 @@ class ConfigurableTask(Task):
             **dataset_kwargs if dataset_kwargs is not None else {},
         )
-    def has_training_docs(self):
+    def has_training_docs(self) -> bool:
         if self._config.training_split is not None:
             return True
         else:
             return False
-    def has_validation_docs(self):
+    def has_validation_docs(self) -> bool:
         if self._config.validation_split is not None:
             return True
         else:
             return False
-    def has_test_docs(self):
+    def has_test_docs(self) -> bool:
         if self._config.test_split is not None:
             return True
         else:
             return False
-    def training_docs(self):
+    def training_docs(self) -> datasets.Dataset:
         if self.has_training_docs():
             if self._config.process_docs is not None:
                 return self._config.process_docs(
@@ -688,7 +687,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self._config.training_split]
-    def validation_docs(self):
+    def validation_docs(self) -> datasets.Dataset:
         if self.has_validation_docs():
             if self._config.process_docs is not None:
                 return self._config.process_docs(
@@ -696,7 +695,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self._config.validation_split]
-    def test_docs(self):
+    def test_docs(self) -> datasets.Dataset:
         if self.has_test_docs():
             if self._config.process_docs is not None:
                 return self._config.process_docs(self.dataset[self._config.test_split])
@@ -767,7 +766,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError
-    def doc_to_target(self, doc):
+    def doc_to_target(self, doc: dict) -> Union[int, str]:
         if self.prompt is not None:
             doc_to_target = self.prompt
@@ -796,7 +795,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError
-    def doc_to_choice(self, doc):
+    def doc_to_choice(self, doc: Any) -> List[str]:
         if self.prompt is not None:
             doc_to_choice = self.prompt
@@ -838,7 +837,9 @@ class ConfigurableTask(Task):
         else:
             raise TypeError
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(
+        self, doc: dict, ctx: str, **kwargs
+    ) -> Union[List[Instance], Instance]:
         if self.OUTPUT_TYPE == "loglikelihood":
             arguments = (ctx, self.doc_to_target(doc))
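
A minimal sketch of the request shape built here for the "loglikelihood" output type; the Instance dataclass below is a simplified stand-in for lm_eval's real Instance class, and the sample doc is invented for illustration:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class Instance:  # simplified stand-in for lm_eval's Instance class (assumption)
    request_type: str
    doc: dict
    arguments: Tuple[str, str]
    idx: int

doc = {"query": "The capital of France is", "choices": [" Paris", " Rome"], "gold": 0}
ctx = doc["query"]
# one loglikelihood request per (context, continuation) pair
requests = [
    Instance(request_type="loglikelihood", doc=doc, arguments=(ctx, choice), idx=i)
    for i, choice in enumerate(doc["choices"])
]
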
@@ -1037,13 +1038,12 @@
 class MultipleChoiceTask(Task):
     OUTPUT_TYPE: str = "loglikelihood"
-    def doc_to_target(self, doc):
+    def doc_to_target(self, doc: dict) -> str:
         return " " + doc["choices"][doc["gold"]]
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
         # TODO: add mutual info here?
         return [
             Instance(
@@ -1056,7 +1056,7 @@ class MultipleChoiceTask(Task):
             for i, choice in enumerate(doc["choices"])
         ]
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: List[Tuple[float, bool]]) -> dict:
         results = [
             res[0] for res in results
         ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
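
A sketch of how the retained per-choice loglikelihoods are typically reduced to acc and acc_norm; the length normalization shown is an assumption, not quoted from the diff:

import numpy as np

doc = {"choices": [" Paris", " Rome"], "gold": 0}
results = [(-1.2, False), (-3.4, False)]  # (loglikelihood, is_greedy) per choice
lls = np.array([res[0] for res in results])  # keep only the loglikelihoods

acc = 1.0 if np.argmax(lls) == doc["gold"] else 0.0
completion_len = np.array([float(len(c)) for c in doc["choices"]])
acc_norm = 1.0 if np.argmax(lls / completion_len) == doc["gold"] else 0.0
print({"acc": acc, "acc_norm": acc_norm})
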
@@ -1071,13 +1071,13 @@ class MultipleChoiceTask(Task):
             "acc_norm": acc_norm,
         }
-    def higher_is_better(self):
+    def higher_is_better(self) -> dict:
         return {
             "acc": True,
             "acc_norm": True,
         }
-    def aggregation(self):
+    def aggregation(self) -> dict:
         return {
             "acc": mean,
             "acc_norm": mean,
@@ -1085,24 +1085,23 @@
 class PerplexityTask(Task):
     OUTPUT_TYPE = "loglikelihood_rolling"
-    def has_training_docs(self):
+    def has_training_docs(self) -> bool:
         return False
-    def fewshot_examples(self, k, rnd):
+    def fewshot_examples(self, k: int, rnd) -> List:
         assert k == 0
         return []
-    def fewshot_context(self, doc, num_fewshot):
+    def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
         assert (
             num_fewshot == 0
         ), "The number of fewshot examples must be 0 for perplexity tasks."
         return ""
-    def higher_is_better(self):
+    def higher_is_better(self) -> dict:
         return {
             "word_perplexity": False,
             "byte_perplexity": False,
@@ -1118,7 +1117,7 @@ class PerplexityTask(Task):
     def doc_to_target(self, doc):
         return doc
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc: dict, ctx: Union[str, None], **kwargs):
         assert not ctx
         return Instance(
@@ -1129,7 +1128,7 @@ class PerplexityTask(Task):
             **kwargs,
         )
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: float) -> dict:
         (loglikelihood,) = results
         words = self.count_words(self.doc_to_target(doc))
         bytes_ = self.count_bytes(self.doc_to_target(doc))
@@ -1139,7 +1138,7 @@ class PerplexityTask(Task):
             "bits_per_byte": (loglikelihood, bytes_),
         }
-    def aggregation(self):
+    def aggregation(self) -> dict:
         return {
             "word_perplexity": weighted_perplexity,
             "byte_perplexity": weighted_perplexity,
@@ -1147,10 +1146,10 @@ class PerplexityTask(Task):
         }
     @classmethod
-    def count_bytes(cls, doc):
+    def count_bytes(cls, doc) -> int:
         return len(doc.encode("utf-8"))
     @classmethod
-    def count_words(cls, doc):
+    def count_words(cls, doc) -> int:
         """Downstream tasks with custom word boundaries should override this!"""
         return len(re.split(r"\s+", doc))
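
A sketch of how the (loglikelihood, word/byte count) pairs emitted by process_results could be aggregated; these weighted_perplexity and bits_per_byte helpers are hypothetical re-implementations, not the functions the task actually imports:

import math
from typing import List, Tuple

def weighted_perplexity(items: List[Tuple[float, int]]) -> float:
    # items are (loglikelihood, weight) pairs, e.g. (loglikelihood, word count) per document
    loglikelihoods, weights = zip(*items)
    return math.exp(-sum(loglikelihoods) / sum(weights))

def bits_per_byte(items: List[Tuple[float, int]]) -> float:
    loglikelihoods, byte_counts = zip(*items)
    return -sum(loglikelihoods) / sum(byte_counts) / math.log(2)

print(weighted_perplexity([(-120.0, 50), (-80.0, 40)]))  # e.g. word perplexity over two docs
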
@@ -456,7 +456,7 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined)
 env.filters["regex_replace"] = regex_replace
-def apply_template(template, doc):
+def apply_template(template: str, doc: dict) -> str:
     rtemplate = env.from_string(template)
     return rtemplate.render(**doc)
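
A self-contained sketch of how the annotated apply_template renders a Jinja2 template string against a document dict (the regex_replace filter registered above is omitted here; the sample template and doc are invented):

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)

def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)

doc = {"question": "2 + 2 = ?", "answer": "4"}
print(apply_template("Q: {{question}}\nA:", doc))  # StrictUndefined raises if a field is missing
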