Commit fc69d84f authored by Ethan Smith's avatar Ethan Smith
Browse files

Add suggestions from autotyping

This adds a bunch of simple annotations suggested by https://github.com/JelleZijlstra/autotyping.
parent da85f290
......@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)
def register_configurable_task(config):
def register_configurable_task(config: dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
......@@ -38,7 +38,7 @@ def register_configurable_task(config):
return 0
def check_prompt_config(config):
def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
all_configs = []
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
......@@ -69,14 +69,14 @@ def check_prompt_config(config):
return all_configs
def get_task_name_from_config(task_config):
def get_task_name_from_config(task_config: dict[str, str]) -> str:
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir):
def include_task_folder(task_dir: str) -> None:
"""
Calling this function
"""
......
def doc_to_text(doc):
def doc_to_text(doc) -> str:
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"],
doc["hypothesis"].strip()
......
......@@ -15,7 +15,7 @@ def _preproc_doc(doc):
return doc
def doc_to_text(doc):
def doc_to_text(doc) -> str:
doc = _preproc_doc(doc)
return f"Scenario 1: {doc['scenarios'][0]}\nScenario 2: {doc['scenarios'][1]}\nQuestion: Is Scenario 1 preferable?\nAnswer:"
......
def doc_to_text(doc):
def doc_to_text(doc) -> str:
ctxs = "\n".join(doc["context"]["contexts"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format(
ctxs, doc["question"], doc["final_decision"]
)
def doc_to_target(doc):
def doc_to_target(doc) -> str:
return " {}".format(doc["final_decision"])
......
......@@ -10,7 +10,7 @@ import collections
import importlib.util
import fnmatch
from typing import List, Literal, Union
from typing import Iterator, List, Literal, Union
import gc
import torch
......@@ -65,7 +65,7 @@ def join_iters(iters):
yield from iter
def chunks(iter, n=0, fn=None):
def chunks(iter, n: int = 0, fn=None):
arr = []
for i, x in enumerate(iter):
arr.append(x)
......@@ -87,11 +87,11 @@ def group(arr, fn):
class MultiChoice:
def __init__(self, choices):
def __init__(self, choices) -> None:
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
def __contains__(self, values) -> bool:
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.info(f"Available tasks to choose:")
......@@ -100,7 +100,7 @@ class MultiChoice:
raise ValueError("'{}' is not in task list".format(value))
return True
def __iter__(self):
def __iter__(self) -> Iterator:
for choice in self.choices:
yield choice
......@@ -108,7 +108,6 @@ class MultiChoice:
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
if type(patterns) == str:
patterns = [patterns]
......@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
class Reorderer:
def __init__(self, arr, fn):
def __init__(self, arr, fn) -> None:
self.size = len(arr)
arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1]))
......@@ -212,7 +211,7 @@ class Grouper:
objects in `arr` satisfying `key == fn(ob)`.
"""
def __init__(self, arr, fn):
def __init__(self, arr, fn) -> None:
# self.orig_arr = arr
self.size = len(arr)
arr = list(enumerate(arr))
......@@ -263,7 +262,7 @@ class Grouper:
return res
def make_table(result_dict, column="results"):
def make_table(result_dict, column: str = "results"):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter
......@@ -393,7 +392,6 @@ def get_git_commit_hash():
def import_function(loader, node):
function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
......@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
......@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
return yaml_config
def regex_replace(string, pattern, repl, count=0):
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
......@@ -521,7 +518,7 @@ def pad_and_concat(
return torch.cat(tensors, dim=0)
def clear_torch_cache():
def clear_torch_cache() -> None:
gc.collect()
torch.cuda.empty_cache()
......@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
) -> None:
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
......
......@@ -15,7 +15,7 @@ from lm_eval.tasks import include_task_folder
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument(
......@@ -98,7 +98,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
args = parse_args()
if args.limit:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment