Commit 97f5c020 authored by baberabb's avatar baberabb
Browse files

added typehints

parent b0f67f2c
...@@ -13,7 +13,7 @@ from tqdm import tqdm ...@@ -13,7 +13,7 @@ from tqdm import tqdm
import datasets import datasets
import numpy as np import numpy as np
from typing import Union from typing import Union, List, Any, Tuple, Literal
from collections.abc import Callable from collections.abc import Callable
from lm_eval import utils from lm_eval import utils
...@@ -477,7 +477,7 @@ class Task(abc.ABC): ...@@ -477,7 +477,7 @@ class Task(abc.ABC):
eval_logger.warning("No filter defined, passing through instances") eval_logger.warning("No filter defined, passing through instances")
return self._instances return self._instances
def dump_config(self): def dump_config(self) -> dict:
"""Returns a dictionary representing the task's config. """Returns a dictionary representing the task's config.
:returns: str :returns: str
...@@ -489,14 +489,13 @@ class Task(abc.ABC): ...@@ -489,14 +489,13 @@ class Task(abc.ABC):
class ConfigurableTask(Task): class ConfigurableTask(Task):
VERSION = "Yaml" VERSION = "Yaml"
OUTPUT_TYPE = None OUTPUT_TYPE = None
CONFIG = None CONFIG = None
def __init__( def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
): ): # TODO no super() call here
# Get pre-configured attributes # Get pre-configured attributes
self._config = self.CONFIG self._config = self.CONFIG
...@@ -662,25 +661,25 @@ class ConfigurableTask(Task): ...@@ -662,25 +661,25 @@ class ConfigurableTask(Task):
**dataset_kwargs if dataset_kwargs is not None else {}, **dataset_kwargs if dataset_kwargs is not None else {},
) )
def has_training_docs(self): def has_training_docs(self) -> bool:
if self._config.training_split is not None: if self._config.training_split is not None:
return True return True
else: else:
return False return False
def has_validation_docs(self): def has_validation_docs(self) -> bool:
if self._config.validation_split is not None: if self._config.validation_split is not None:
return True return True
else: else:
return False return False
def has_test_docs(self): def has_test_docs(self) -> bool:
if self._config.test_split is not None: if self._config.test_split is not None:
return True return True
else: else:
return False return False
def training_docs(self): def training_docs(self) -> datasets.Dataset:
if self.has_training_docs(): if self.has_training_docs():
if self._config.process_docs is not None: if self._config.process_docs is not None:
return self._config.process_docs( return self._config.process_docs(
...@@ -688,7 +687,7 @@ class ConfigurableTask(Task): ...@@ -688,7 +687,7 @@ class ConfigurableTask(Task):
) )
return self.dataset[self._config.training_split] return self.dataset[self._config.training_split]
def validation_docs(self): def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs(): if self.has_validation_docs():
if self._config.process_docs is not None: if self._config.process_docs is not None:
return self._config.process_docs( return self._config.process_docs(
...@@ -696,7 +695,7 @@ class ConfigurableTask(Task): ...@@ -696,7 +695,7 @@ class ConfigurableTask(Task):
) )
return self.dataset[self._config.validation_split] return self.dataset[self._config.validation_split]
def test_docs(self): def test_docs(self) -> datasets.Dataset:
if self.has_test_docs(): if self.has_test_docs():
if self._config.process_docs is not None: if self._config.process_docs is not None:
return self._config.process_docs(self.dataset[self._config.test_split]) return self._config.process_docs(self.dataset[self._config.test_split])
...@@ -767,7 +766,7 @@ class ConfigurableTask(Task): ...@@ -767,7 +766,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc): def doc_to_target(self, doc: dict) -> Union[int, str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
...@@ -796,7 +795,7 @@ class ConfigurableTask(Task): ...@@ -796,7 +795,7 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_choice(self, doc): def doc_to_choice(self, doc: Any) -> List[str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_choice = self.prompt doc_to_choice = self.prompt
...@@ -838,7 +837,9 @@ class ConfigurableTask(Task): ...@@ -838,7 +837,9 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc)) arguments = (ctx, self.doc_to_target(doc))
...@@ -1037,13 +1038,12 @@ class ConfigurableTask(Task): ...@@ -1037,13 +1038,12 @@ class ConfigurableTask(Task):
class MultipleChoiceTask(Task): class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood" OUTPUT_TYPE: str = "loglikelihood"
def doc_to_target(self, doc): def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]] return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
# TODO: add mutual info here? # TODO: add mutual info here?
return [ return [
Instance( Instance(
...@@ -1056,7 +1056,7 @@ class MultipleChoiceTask(Task): ...@@ -1056,7 +1056,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"]) for i, choice in enumerate(doc["choices"])
] ]
def process_results(self, doc, results): def process_results(self, doc: dict, results: List[Tuple[float, bool]]) -> dict:
results = [ results = [
res[0] for res in results res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
...@@ -1071,13 +1071,13 @@ class MultipleChoiceTask(Task): ...@@ -1071,13 +1071,13 @@ class MultipleChoiceTask(Task):
"acc_norm": acc_norm, "acc_norm": acc_norm,
} }
def higher_is_better(self): def higher_is_better(self) -> dict:
return { return {
"acc": True, "acc": True,
"acc_norm": True, "acc_norm": True,
} }
def aggregation(self): def aggregation(self) -> dict:
return { return {
"acc": mean, "acc": mean,
"acc_norm": mean, "acc_norm": mean,
...@@ -1085,24 +1085,23 @@ class MultipleChoiceTask(Task): ...@@ -1085,24 +1085,23 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task): class PerplexityTask(Task):
OUTPUT_TYPE = "loglikelihood_rolling" OUTPUT_TYPE = "loglikelihood_rolling"
def has_training_docs(self): def has_training_docs(self) -> bool:
return False return False
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k: int, rnd) -> List:
assert k == 0 assert k == 0
return [] return []
def fewshot_context(self, doc, num_fewshot): def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
assert ( assert (
num_fewshot == 0 num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks." ), "The number of fewshot examples must be 0 for perplexity tasks."
return "" return ""
def higher_is_better(self): def higher_is_better(self) -> dict:
return { return {
"word_perplexity": False, "word_perplexity": False,
"byte_perplexity": False, "byte_perplexity": False,
...@@ -1118,7 +1117,7 @@ class PerplexityTask(Task): ...@@ -1118,7 +1117,7 @@ class PerplexityTask(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc return doc
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc: dict, ctx: Union[str, None], **kwargs):
assert not ctx assert not ctx
return Instance( return Instance(
...@@ -1129,7 +1128,7 @@ class PerplexityTask(Task): ...@@ -1129,7 +1128,7 @@ class PerplexityTask(Task):
**kwargs, **kwargs,
) )
def process_results(self, doc, results): def process_results(self, doc: dict, results: float) -> dict:
(loglikelihood,) = results (loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc)) words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc)) bytes_ = self.count_bytes(self.doc_to_target(doc))
...@@ -1139,7 +1138,7 @@ class PerplexityTask(Task): ...@@ -1139,7 +1138,7 @@ class PerplexityTask(Task):
"bits_per_byte": (loglikelihood, bytes_), "bits_per_byte": (loglikelihood, bytes_),
} }
def aggregation(self): def aggregation(self) -> dict:
return { return {
"word_perplexity": weighted_perplexity, "word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity, "byte_perplexity": weighted_perplexity,
...@@ -1147,10 +1146,10 @@ class PerplexityTask(Task): ...@@ -1147,10 +1146,10 @@ class PerplexityTask(Task):
} }
@classmethod @classmethod
def count_bytes(cls, doc): def count_bytes(cls, doc) -> int:
return len(doc.encode("utf-8")) return len(doc.encode("utf-8"))
@classmethod @classmethod
def count_words(cls, doc): def count_words(cls, doc) -> int:
"""Downstream tasks with custom word boundaries should override this!""" """Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc)) return len(re.split(r"\s+", doc))
...@@ -456,7 +456,7 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined) ...@@ -456,7 +456,7 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace env.filters["regex_replace"] = regex_replace
def apply_template(template, doc): def apply_template(template: str, doc: dict) -> str:
rtemplate = env.from_string(template) rtemplate = env.from_string(template)
return rtemplate.render(**doc) return rtemplate.render(**doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment